「PyTorch」の版間の差分
ナビゲーションに移動
検索に移動
(ページの作成:「https://pytorch.org/」) |
(→学習) |
||
(同じ利用者による、間の1版が非表示) | |||
1行目: | 1行目: | ||
https://pytorch.org/ | https://pytorch.org/ | ||
AIをつくろう | |||
== やりかた == | |||
=== データセットの準備 === | |||
https://pytorch.org/docs/stable/data.html | |||
* Datasetを作成 | |||
** 適宜、データ拡張 | |||
* DataLoaderを作成 | |||
=== モデルの定義 === | |||
モデルは<syntaxhighlight inline>torch.nn.Module</syntaxhighlight>の派生クラスである必要がある。 | |||
<syntaxhighlight inline>torch.nn.Sequential</syntaxhighlight> | |||
や | |||
<syntaxhighlight inline>torch.nn.ModuleList</syntaxhighlight> | |||
を使うこともできる。 | |||
=== 学習 === | |||
<syntaxhighlight inline>model.train()</syntaxhighlight>を実行しておく。 | |||
各イテレーションにおいて必要な処理は以下の通り。 | |||
# <syntaxhighlight inline>torch.Optimizer.zero_grad()</syntaxhighlight>または <syntaxhighlight inline>nn.Module.zero_grad()</syntaxhighlight>[https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.zero_grad]で勾配を初期化しておく。 | |||
# モデルにデータを入力して出力を得る(順伝播) | |||
# [[損失関数]]を計算する。 | |||
# 勾配を計算する。<syntaxhighlight inline>loss.backward()</syntaxhighlight> | |||
# 誤差逆伝播 <syntaxhighlight inline>optim.step()</syntaxhighlight> | |||
=== テスト === | |||
<syntaxhighlight inline>model.eval()</syntaxhighlight>を実行しておく。 | |||
これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。 | |||
また、<syntaxhighlight inline>with torch.no_grad():</syntaxhighlight>としておけば勾配の計算が省かれ、速くなる。 |
2023年2月10日 (金) 01:03時点における最新版
AIをつくろう
やりかた
データセットの準備
https://pytorch.org/docs/stable/data.html
- Datasetを作成
- 適宜、データ拡張
- DataLoaderを作成
モデルの定義
モデルはtorch.nn.Module
の派生クラスである必要がある。
torch.nn.Sequential
や
torch.nn.ModuleList
を使うこともできる。
学習
model.train()
を実行しておく。
各イテレーションにおいて必要な処理は以下の通り。
torch.Optimizer.zero_grad()
またはnn.Module.zero_grad()
[1]で勾配を初期化しておく。- モデルにデータを入力して出力を得る(順伝播)
- 損失関数を計算する。
- 勾配を計算する。
loss.backward()
- 誤差逆伝播
optim.step()
テスト
model.eval()
を実行しておく。
これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。
また、with torch.no_grad():
としておけば勾配の計算が省かれ、速くなる。