JAXとはPythonで何ですか?
高速なNumPyとして使いこなすためのチュートリアル~
Google製のライブラリで、AutogradとXLAからなる、機械学習のための数値計算ライブラリ。簡単に言うと「自動微分に特化した、GPUやTPUに対応した高速なNumPy」。NumPyとほとんど同じ感覚で書くことができます。自動微分については解説が多いので、この記事では単なる高速なNumPyの部分を中心に書いていきます。
関連記事
GPU対応のNumPyという観点では、似たライブラリとしてPFN製のCuPyや、AnacondaがスポンサーとなっているNumbaもあります。
配列の初期化
最初はCPUに限定して書きます。JAXの導入はとてもシンプルで、あたかもNumPyのように使うことができます。
import jax.numpy as jnp
# NumPyではnp.arange(25, dtype=np.float32).reshape(5, 5)
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 0. 1. 2. 3. 4.]
[ 5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]
[20. 21. 22. 23. 24.]]
JAXでのNumPy関数
.block_until_ready()
NumPy関数はnpをjnpに書き換えるだけ。ただし、JAXでは非同期処理で計算されるため、計算の最後に.block_until_ready()を追加します。
# NumPyではnp.dot(x, x.T)
x_gram = jnp.dot(x, x.T).block_until_ready()
print(x_gram)
[[ 30. 80. 130. 180. 230.]
[ 80. 255. 430. 605. 780.]
[ 130. 430. 730. 1030. 1330.]
[ 180. 605. 1030. 1455. 1880.]
[ 230. 780. 1330. 1880. 2430.]]
特に理由がなければ.block_until_ready()はJAXの計算の最後のみ入れればOKです。
y = x + 1
x_gram = jnp.dot(x, y.T).block_until_ready() # 最後だけブロッキングを入れればOK
詳細は下記へ
ref