見出し画像

多値分類の閾値調整: Optunaを用いた精度向上

多値分類のコンペが苦手だ。分類問題では後処理での閾値調整が有効らしいけど、その辺の理解がまだ不十分だった。
Kaggleの初心者向けコンペでoptunaを使った閾値調整が使われていたので、ChatGPTにも聞きながら整理してみた。


多値分類問題の閾値調整

多値分類モデルは、入力データが各カテゴリに属する確率を出力します。一般に、最も確率が高いカテゴリを予測結果として選びますが、この方法が最良とは限りません。実際にモデルを使う場面では、カテゴリを間違えたときの影響(コスト)はカテゴリによって異なります。例えば、医療診断で重大な病気を見逃すコストは、誤って病気と診断するコストよりもはるかに大きいです。このような場合、カテゴリごとに予測を決めるための基準(閾値)を変えることで、モデル全体の効果を高めることができます。

1. adjust_thresholds関数

予測された確率に基づいて、各クラスの閾値を超えるかどうかを判断し、適切なクラスを選択します。これにより、モデルの予測精度が向上することが期待されます。

# adjust_thresholds関数
def adjust_thresholds(pred_probs, thresholds):
    """
    予測確率に基づいてクラスを調整する。
    各サンプルに対して、設定された閾値以上の確率を持つクラスを予測クラスとする。
    
    Parameters:
    pred_probs: numpy.ndarray, 各クラスに対する予測確率の配列 (サンプル数 x クラス数)
    thresholds: list, クラスごとの閾値のリスト (クラス数)
    
    Returns:
    numpy.ndarray, 調整された予測クラスの配列
    """
    adjusted_preds = np.argmax(pred_probs >= thresholds, axis=1)
    return adjusted_preds

2. objective関数 (Optuna用)

F1スコアを最大化するために、クラスごとの最適な閾値を探索し、選択します。Optunaの強力な最適化機能を活用して、最も性能の良い閾値を見つけ出します。

# objective関数 for Optuna
def objective(trial, true_labels, pred_probs, num_classes):
    """
    Optunaの最適化プロセスで使用される目的関数。
    異なる閾値を試し、F1スコアを最大化する閾値の組み合わせを見つける。
    
    Parameters:
    trial: optuna.trial.Trial, Optunaの試行オブジェクト
    true_labels: numpy.ndarray, 実際のクラスラベルの配列
    pred_probs: numpy.ndarray, 各クラスに対する予測確率の配列 (サンプル数 x クラス数)
    num_classes: int, クラスの総数
    
    Returns:
    float, 試行におけるF1スコア
    """
    thresholds = [trial.suggest_uniform(f"threshold_{i}", 0.0, 1.0) for i in range(num_classes)]
    preds = np.argmax(pred_probs >= np.array(thresholds), axis=1)
    score = f1_score(true_labels, preds, average='macro')
    return score

3. find_optimal_thresholds関数

最適な閾値を見つけ出すこの関数は、上述したobjective関数を用いて、クラスごとの閾値を最適化します。Optunaの最適化フレームワークを活用し、試行錯誤を繰り返しながら最も性能が良い閾値を選択します。

# find_optimal_thresholds関数
def find_optimal_thresholds(true_labels, pred_probs, num_classes, n_trials=100):
    """
    F1スコアを最大化するための最適な閾値を見つける。
    
    Parameters:
    true_labels: numpy.ndarray, 実際のクラスラベルの配列
    pred_probs: numpy.ndarray, 各クラスに対する予測確率の配列 (サンプル数 x クラス数)
    num_classes: int, クラスの総数
    n_trials: int, 試行回数
    
    Returns:
    list, 各クラスにおける最適な閾値のリスト
    """
    study = optuna.create_study(direction='maximize')
    study.optimize(lambda trial: objective(trial, true_labels, pred_probs, num_classes), n_trials=n_trials)
    optimal_thresholds = [study.best_params[f"threshold_{i}"] for i in range(num_classes)]
    return optimal_thresholds

合成データセットでの検証

make_classification関数を使用して生成した合成データセットに対し、ロジスティック回帰モデルを適用して検証しました。閾値調整前の精度は0.68でしたが、閾値調整後は0.73へと改善されました。この結果は、適切な閾値調整がモデルの性能向上に効果的であることを示しています。

実行結果:
Original Accuracy: 0.6833333333333333
Adjusted Accuracy: 0.73
Optimal_Thresholds:
[0.3654747293820809, 0.32332056874795734, 0.5313018570044121]

!pip install scikit-learn optuna
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import optuna
# 合成データセットの生成
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, n_classes=3, random_state=42)

# データを訓練セットとテストセットに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# ロジスティック回帰モデルを訓練
clf = LogisticRegression(random_state=42, max_iter=200)
clf.fit(X_train, y_train)

# テストセットに対する予測確率を取得
pred_probs = clf.predict_proba(X_test)


# 最適な閾値を見つける
optimal_thresholds = find_optimal_thresholds(y_test, pred_probs, num_classes=3, n_trials=50)

# 閾値を使って予測クラスを調整
adjusted_preds = adjust_thresholds(pred_probs, optimal_thresholds)

# 評価(調整前と調整後の予測の精度を比較)
original_preds = np.argmax(pred_probs, axis=1)
original_accuracy = accuracy_score(y_test, original_preds)
adjusted_accuracy = accuracy_score(y_test, adjusted_preds)

print("Original Accuracy:", original_accuracy)
print("Adjusted Accuracy:", adjusted_accuracy)
print("Optimal_Thresholds:", optimal_thresholds)

注意点

  • 過学習のリスク: 閾値の最適化は、特に訓練データに対して過剰に適合するリスクを伴います。適切な交差検証を行うことで、このリスクを管理しましょう。

  • 問題設定の特性: 最適な閾値は、使用するデータセットや問題設定によって異なります。実験を重ね、最も効果的な閾値を見つけ出すことが重要です。

閾値調整は、多値分類問題における性能向上のための強力なツールです。Optunaなどのツールを活用することで、効率的に最適な閾値を見つけ出し、モデルの予測精度を向上させることが可能になります。



いいなと思ったら応援しよう!