PyTorch
ナビゲーションに移動
検索に移動
AIをつくろう
やりかた
データセットの準備
https://pytorch.org/docs/stable/data.html
- Datasetを作成
- 適宜、データ拡張
- DataLoaderを作成
モデルの定義
モデルはtorch.nn.Module
の派生クラスである必要がある。
torch.nn.Sequential
や
torch.nn.ModuleList
を使うこともできる。
学習
model.train()
を実行しておく。
各イテレーションにおいて必要な処理は以下の通り。
torch.Optimizer.zero_grad()
またはnn.Module.zero_grad()
[1]で勾配を初期化しておく。- モデルにデータを入力して出力を得る(順伝播)
- 損失関数を計算する。
- 勾配を計算する。
loss.backward()
- 誤差逆伝播
optim.step()
テスト
model.eval()
を実行しておく。
これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。
また、with torch.no_grad():
としておけば勾配の計算が省かれ、速くなる。