dtreevizを使ってみた
dtreevizを使って決定木アルゴリズムを可視化してみます。決定木アルゴリズムとはlightGBMやRandomForestなどを言います。
実験用データとして、下記のポケモンデータを使用します。このポケモンデータのLegendaryを予測するモデルを作成し、その予測をdtreevizで可視化します。
データの確認
モデルの作成を行う前にデータを簡単に確認しておきましょう。
import polars as pl
df = pl.read_csv('Pokemon.csv')
df.head()
伝説ポケモンの数を確認します。
df['legendary'].value_counts()
予測モデルの作成
lightGBMで伝説ポケモンかどうかを予測するモデルを作ることにします。
テストデータと学習データに分けます。
さらに、特徴量として使う列と予測対象の列をセレクトしてlightGBMで扱える形に変換します。
df = df.sample(fraction=1, shuffle=True, seed=43)
test_size = 250
test, train = df.head(test_size), df.tail(-test_size)
feature_cols = ['hp', 'attack', 'defense', 'sp_attack', 'sp_defense', 'speed']
target_col = ['legendary']
X_train = train.select(feature_cols).cast(pl.Float32).to_pandas()
y_train = train.select(target_col).cast(pl.Float32).to_pandas()
X_test = test.select(feature_cols).cast(pl.Float32).to_pandas()
y_test = test.select(target_col).cast(pl.Float32).to_pandas()
dtrain = lgb.Dataset(X_train, y_train)
dvalid = lgb.Dataset(X_test, y_test)
次に、lightGBMのパラメータを設定して学習を実行します。
ここで、予測モデルの解釈性を高めるためにlightGBMのパラメータであるnum_leavesを7に設定しておきます。
params = {
'objective': 'binary', # 2値分類
'random_state': 43, # 乱数シード,
'num_leaves': 7,
'eta': 0.5,
'verbose': -1,
}
model = lgb.train(
params = params,
train_set = dtrain, # 学習用データセット
num_boost_round = 100, # Boostingの繰り返し回数
callbacks=[
lgb.early_stopping(stopping_rounds=10, verbose=True), # early_stopping用コールバック関数
lgb.log_evaluation(0)
],
valid_sets = [dvalid] # 評価用データセット
)
モデルができたので、精度を確認しておきます。
from sklearn.metrics import accuracy_score
y_pred = model.predict(X_test)
y_pred = [1.0 if pred > 0.5 else 0.0 for pred in y_pred]
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')
ret = pl.DataFrame({
'index': [n for n in range(len(y_pred))],
'GrounTurth': y_test['legendary'].to_list(),
'Predict': y_pred
})
ret
テスト用データのAccuracyは0.94という結果でした。
dtreevizでモデルを可視化する
可視化は以下のコードを実行する。
import dtreeviz
viz_model = dtreeviz.model(
model,
tree_index=1,
X_train=X_train,
y_train=y_train[target_col[0]].astype(int),
feature_names=feature_cols,
target_name=target_col, class_names=['Non Legendary', 'Legendary']
)
viz_model.view(orientation='LR')
上記のコードはモデル全体ではなく、lightGBMの1番目の決定木のみを表しています。tree_indexの値を1としている箇所を2, 3と変更すると2番目、3番目の決定木を表示できます。
では、今回のモデルは何番目までの決定木があるのか確認します。それは以下のコードを使って確認できます。
model.num_trees()
上記を実行するとという値が11出力されました。つまり、決定木は11本連なっているようです。
lightGBMの仕組みや構造については下記の記事に詳しく書かれています。
データの予測根拠を確認する
dtreevizでは、あるデータがどのような経緯で最終的な予測になったかを可視化することができます。
今回のモデルではテストデータのindex 245のデータが予測が間違っています。このデータの予測根拠を見てましょう。下記を実行してみます。
target_index = 245
viz_model.view(x=X_test.iloc[target_index], show_just_path=True, orientation='LR')
では、1本目の決定木の値のみを使って予測根拠が説明できたといっていいのでしょうか?確認してみましょう。
各決定木の出力時点で予測スコアがどのように変化しているかを下記コードで確認します。
score = model.predict(X_test.iloc[target_index])
print(f'予測スコア:{score}')
for n in range(1, model.num_trees()+1):
score = model.predict(X_test.iloc[target_index], num_iteration=n)
print(f'{n}本目の決定木だけを使った時の予測スコア:{score}')
4本目時点で0.2までスコアが上がっています。1本目以外の決定木の予測根拠を確認して間違った理由を判断した方が良さそうです。
一方で、予測が正解しているindex 246を見てみましょう。
こちらは1本目で0.65のスコアになっており、以降の決定木でもスコアが上がり続けています。1本目の予測をほかの決定木も支持する方向性にはなっているので、すべての決定木を見なくても、その根拠は説明できそうです。
まとめ
lightGBMで作った予測モデルをdtreevizを使って可視化しました。可視化することによって、lightGBMの予測根拠を明らかにすることができますが、予測根拠を100%説明するには、多数の決定木を読み解く必要があります。
そこで、各決定木の出力する予測スコアを見ることで、どこまでの決定木を見れば予測根拠がわかったと言えそうかの判断の手がかりになりそうです。