G検定 JAX
株式会社リュディアです。今回は JAX についてまとめてみます。以前にまとめたAI周辺のオープンソースソフトウェアに追加すべきでした。
JAXは数値計算や機械学習のために設計された Google 製の Python ライブラリで Autograd と XLA からなります。高速な NumPy のイメージのようです。以下に GitHub にあるソースコードへのリンクをつけておきます。
JAXの重要な構成要素に Autograd、つまり自動微分があります。Python で記述した関数に対する偏導関数を自動的に求める機能です。ディープラーニングを実装する上で非常に重要な機能です。Pytorch にも autograd がありますがさらに高速化されています。
もう1つの重要な構成要素である XLA は線形代数の演算に特化したコンパイラに対応しており @jit を指定されたフォーマットの関数に付加することで実行時に自動的にコンパイルされ高速演算が可能になります。コンパイルといっても JITコンパイル と呼ばれる方式で中間コードを出力し、最終的にそれを GPU やTPU (TensorFlow Processing Unit) 向けの専用コードに変換して実行します。
ちなみに JITコンパイル とは Just-In-Time コンパイルのことです。日本語では「その都度コンパイラ」と言います。Java が発表されたときに話題になった技術です。
では、ごきげんよう。