見出し画像

高速で解釈性の高い FastSGLモデル (Ver.3.9.6)

時系列データ分析ツール Node-AI スクラムマスターの 中野 です!

今回のアップデートでは、高速で解釈性の高いモデル FastSGL をリリースしています!

詳細な特徴や利用方法についてはマニュアルをご確認ください。

FastSGLは NTTコンピュータ&データサイエンス研究所 が開発した手法で、通常のSGL(Sparse Group Lasso)と比較して学習速度が高速であることが特徴です。

FastSGLは NeurIPS というAI系のトップカンファレンスにも採択されています(論文)!
このようにNTTグループでは世界最先端の研究開発を行い、その成果をプロダクトに実装する活動を進めています!

さて、FastSGLがどのように嬉しいのか検証したのでその結果を共有させていただきます。

Sparse ”Group" Lasso という名称のとおり、FastSGLは特徴量をグループ化する仕組みを持っています。
グループ化によりモデルの解釈性が高まり、「AIモデルがどのように動いているのか」を直感的に理解することが容易になります。
このグループの構造は本来任意で設定できるものですが、Node-AIでは「1カラムが1グループ」として自動で設定されます(将来的に変更される可能性があります)。

言葉で説明するのは大変なので、実例を示しながら解説していきます。

今回デモで利用するのは、公開データにもある「北京PM2.5濃度予測」データです。
目的変数として "pm2.5" 、説明変数として "pm2.5" と "cbwd" 以外を選択します。
※この検証では結果をわかりやすくするために意図的な調整をしています。精度の高いモデルを作ることが目標ではない点をご了承ください。

使用するデータと前処理


時間窓切り出しにより、24時間分のデータを使って6時間後の値を解く問題とします。
これ以降は少し複雑になりますが、以下のようなツリーを作成します。

※このツリーでは要因の簡易的な確認のため、評価用のデータを作成していません。

FastSGL検証用のツリー

モデルは3種類用意します。

  • 線形モデルのLasso(alpha=0.1)

  • FastSGLモデル(alpha=0.999、rho=0.1)

  • FastSGLモデル(alpha=0.8、rho=0.1)

ややこしいのですが、FastSGLにおける "rho" がLassoにおける "alpha" に相当します。そのため、3モデルは alpha= 0.1 のLassoを基本にしていることになります。

FastSGLのalphaはグループ化の強度を設定します。1に近いほど通常のLassoとなるため、alpha=0.999のFastSGLは「ほぼLasso」ということになります。もうひとつのalpha=0.8のFastSGLは「少しグループ化の効果を強くしたLasso」となります。

この3モデルの要因分析の結果を見てみましょう。
ここで、LassoやFastSGLの要因分析結果は「各特徴量の重み」をヒートマップで表示していることにご注意ください。

Lasso (alpha=0.1)の要因分析結果
FastSGL (alpha=0.999, rho=0.1)の要因分析結果

まずこの2モデルの要因分析結果は非常に似ていることがわかります。
(2つ目のモデルは「ほぼLasso」なので当然ですが…)
"DEWP"、"TEMP"、"Iws" といったカラムの中でまばらに重みが付いています。

これこそがLassoの重要な「特徴選択」の効果です。たくさん(例では6カラム×24分=144個)ある特徴量の中から一部の特徴量のみ重みが付いているので、「モデルがどの特徴量に着目しているのか」が人間が見てもわかりやすくなります。

ただ一方で、「なぜその中でも特定の時刻をまばらに選択しているのだろう?」と疑問も湧いてきますよね。
そこで3つ目のモデルの要因分析結果も見てみましょう。

FastSGL (alpha=0.8, rho=0.1) の要因分析結果

その他の2つのモデルよりも「わかりやすい重み」であると思いませんか?

重みが付いているのは "TEMP" と "Iws" だけになっていますし、時刻も飛び飛びではなくスムージングされたように滑らかに変化しています。

これがFastSGL(というよりSGL)の強みになります。
前述の通り、Node-AIでは1カラムが1グループに設定されるので、各カラムをひとつの塊のように捉えたモデルを作成することができます。

※ただし、解釈性が上がるからと言って精度が上がることには繋がりません。実際にこの例でもFastSGLはLassoより精度が低下しています。
※またFastSGLは「元となるSGLが」高速化されたアルゴリズムです。通常のLassoより高速になるというわけではありません(実験的には同等か、少し遅い程度の学習速度になります)。

いかがでしたか?モデルの解釈性が低く困っていたみなさま、ぜひご利用ください!

いいなと思ったら応援しよう!