PyTorchが計算グラフを作成する仕組み
小さな小さなPyTorchクローンを作ろうと思ったので、PyTorchの計算グラフの仕組みを調べました。
公式ブログに分かりやすいGIFアニメがあったので、メモに残します。
1.requires_gradなテンソルを関数に入力
2.grad_fn(勾配関数)ノードが作成される
3.「collect_next_edges」が入力関数からgrad_fnにリンクするエッジを作成
6.同様にして関数Bを作成
7.3つのテンソルを関数Cへ入力
9.「collect_next_edges」が入力の関数を調べて、grad_fnへのエッジを作成する
10.終了。