見出し画像

dtreevizを使ってみた


dtreevizを使って決定木アルゴリズムを可視化してみます。決定木アルゴリズムとはlightGBMやRandomForestなどを言います。
実験用データとして、下記のポケモンデータを使用します。このポケモンデータのLegendaryを予測するモデルを作成し、その予測をdtreevizで可視化します。

データの確認

モデルの作成を行う前にデータを簡単に確認しておきましょう。

import polars as pl

df = pl.read_csv('Pokemon.csv')
df.head()
df.head()の出力

伝説ポケモンの数を確認します。

df['legendary'].value_counts()
df['legendary'].value_counts()の出力

予測モデルの作成

lightGBMで伝説ポケモンかどうかを予測するモデルを作ることにします。
テストデータと学習データに分けます。
さらに、特徴量として使う列と予測対象の列をセレクトしてlightGBMで扱える形に変換します。

df = df.sample(fraction=1, shuffle=True, seed=43)

test_size = 250
test, train = df.head(test_size), df.tail(-test_size)

feature_cols = ['hp', 'attack', 'defense', 'sp_attack', 'sp_defense', 'speed']
target_col = ['legendary']
X_train = train.select(feature_cols).cast(pl.Float32).to_pandas()
y_train = train.select(target_col).cast(pl.Float32).to_pandas()
X_test = test.select(feature_cols).cast(pl.Float32).to_pandas()
y_test = test.select(target_col).cast(pl.Float32).to_pandas()
dtrain = lgb.Dataset(X_train, y_train)
dvalid = lgb.Dataset(X_test, y_test)

次に、lightGBMのパラメータを設定して学習を実行します。
ここで、予測モデルの解釈性を高めるためにlightGBMのパラメータであるnum_leavesを7に設定しておきます。

params = { 
    'objective': 'binary', # 2値分類
    'random_state': 43,  # 乱数シード,
    'num_leaves': 7,
    'eta': 0.5,
    'verbose': -1,
}

model = lgb.train(
    params = params, 
    train_set = dtrain, # 学習用データセット
    num_boost_round = 100, # Boostingの繰り返し回数
    callbacks=[
        lgb.early_stopping(stopping_rounds=10, verbose=True), # early_stopping用コールバック関数
        lgb.log_evaluation(0)
    ],
    valid_sets = [dvalid] # 評価用データセット
)

モデルができたので、精度を確認しておきます。

from sklearn.metrics import accuracy_score
y_pred = model.predict(X_test)
y_pred = [1.0 if pred > 0.5 else 0.0 for pred in y_pred]
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')

ret = pl.DataFrame({
    'index': [n for n in range(len(y_pred))],
    'GrounTurth': y_test['legendary'].to_list(),
    'Predict': y_pred
})
ret

テスト用データのAccuracyは0.94という結果でした。

dtreevizでモデルを可視化する

可視化は以下のコードを実行する。

import dtreeviz

viz_model = dtreeviz.model(
    model,
    tree_index=1,
    X_train=X_train,
    y_train=y_train[target_col[0]].astype(int),
    feature_names=feature_cols,
    target_name=target_col, class_names=['Non Legendary', 'Legendary']
)
viz_model.view(orientation='LR')

上記のコードはモデル全体ではなく、lightGBMの1番目の決定木のみを表しています。tree_indexの値を1としている箇所を2, 3と変更すると2番目、3番目の決定木を表示できます。

1本目の決定木の可視化

では、今回のモデルは何番目までの決定木があるのか確認します。それは以下のコードを使って確認できます。

model.num_trees()

上記を実行するとという値が11出力されました。つまり、決定木は11本連なっているようです。
lightGBMの仕組みや構造については下記の記事に詳しく書かれています。

データの予測根拠を確認する

dtreevizでは、あるデータがどのような経緯で最終的な予測になったかを可視化することができます。
今回のモデルではテストデータのindex 245のデータが予測が間違っています。このデータの予測根拠を見てましょう。下記を実行してみます。

target_index = 245
viz_model.view(x=X_test.iloc[target_index], show_just_path=True, orientation='LR')
index 245 の予測根拠1

では、1本目の決定木の値のみを使って予測根拠が説明できたといっていいのでしょうか?確認してみましょう。
各決定木の出力時点で予測スコアがどのように変化しているかを下記コードで確認します。

score = model.predict(X_test.iloc[target_index])
print(f'予測スコア:{score}')

for n in range(1, model.num_trees()+1):
    score = model.predict(X_test.iloc[target_index], num_iteration=n)
    print(f'{n}本目の決定木だけを使った時の予測スコア:{score}')
index 245の予測スコア

4本目時点で0.2までスコアが上がっています。1本目以外の決定木の予測根拠を確認して間違った理由を判断した方が良さそうです。

一方で、予測が正解しているindex 246を見てみましょう。

index 246 の予測根拠1
index 246の予測スコア

こちらは1本目で0.65のスコアになっており、以降の決定木でもスコアが上がり続けています。1本目の予測をほかの決定木も支持する方向性にはなっているので、すべての決定木を見なくても、その根拠は説明できそうです。

まとめ

lightGBMで作った予測モデルをdtreevizを使って可視化しました。可視化することによって、lightGBMの予測根拠を明らかにすることができますが、予測根拠を100%説明するには、多数の決定木を読み解く必要があります。
そこで、各決定木の出力する予測スコアを見ることで、どこまでの決定木を見れば予測根拠がわかったと言えそうかの判断の手がかりになりそうです。

この記事が気に入ったらサポートをしてみませんか?