「PyTorch」の版間の差分
ナビゲーションに移動
検索に移動
編集の要約なし |
(→学習) |
||
22行目: | 22行目: | ||
<syntaxhighlight inline>model.train()</syntaxhighlight>を実行しておく。 | <syntaxhighlight inline>model.train()</syntaxhighlight>を実行しておく。 | ||
各イテレーションにおいて必要な処理は以下の通り。 | 各イテレーションにおいて必要な処理は以下の通り。 | ||
# <syntaxhighlight inline>torch.Optimizer.zero_grad()</syntaxhighlight>または <syntaxhighlight inline>nn.Module.zero_grad()</syntaxhighlight>[https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.zero_grad]で勾配を初期化しておく。 | |||
# モデルにデータを入力して出力を得る(順伝播) | # モデルにデータを入力して出力を得る(順伝播) | ||
# [[損失関数]]を計算する。 | # [[損失関数]]を計算する。 | ||
# 勾配を計算する。<syntaxhighlight inline>loss.backward()</syntaxhighlight> | # 勾配を計算する。<syntaxhighlight inline>loss.backward()</syntaxhighlight> | ||
# 誤差逆伝播 <syntaxhighlight inline>optim.step()</syntaxhighlight> | # 誤差逆伝播 <syntaxhighlight inline>optim.step()</syntaxhighlight> |
2023年2月10日 (金) 01:03時点における最新版
AIをつくろう
やりかた
データセットの準備
https://pytorch.org/docs/stable/data.html
- Datasetを作成
- 適宜、データ拡張
- DataLoaderを作成
モデルの定義
モデルはtorch.nn.Module
の派生クラスである必要がある。
torch.nn.Sequential
や
torch.nn.ModuleList
を使うこともできる。
学習
model.train()
を実行しておく。
各イテレーションにおいて必要な処理は以下の通り。
torch.Optimizer.zero_grad()
またはnn.Module.zero_grad()
[1]で勾配を初期化しておく。- モデルにデータを入力して出力を得る(順伝播)
- 損失関数を計算する。
- 勾配を計算する。
loss.backward()
- 誤差逆伝播
optim.step()
テスト
model.eval()
を実行しておく。
これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。
また、with torch.no_grad():
としておけば勾配の計算が省かれ、速くなる。