PyTorch

提供:ペチラボ書庫
ナビゲーションに移動 検索に移動

https://pytorch.org/

AIをつくろう

やりかた

データセットの準備

https://pytorch.org/docs/stable/data.html

  • Datasetを作成
    • 適宜、データ拡張
  • DataLoaderを作成

モデルの定義

モデルはtorch.nn.Moduleの派生クラスである必要がある。

torch.nn.Sequentialtorch.nn.ModuleList を使うこともできる。

学習

model.train()を実行しておく。 各イテレーションにおいて必要な処理は以下の通り。

  1. torch.Optimizer.zero_grad()または nn.Module.zero_grad()[1]で勾配を初期化しておく。
  2. モデルにデータを入力して出力を得る(順伝播)
  3. 損失関数を計算する。
  4. 勾配を計算する。loss.backward()
  5. 誤差逆伝播 optim.step()

テスト

model.eval()を実行しておく。 これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。 また、with torch.no_grad():としておけば勾配の計算が省かれ、速くなる。