見出し画像

STaR: 自己教示型推論器 - 推論による推論のブートストラップ

要旨

STaRは段階的な「思考の連鎖」根拠生成を通じて、複雑な推論タスクのパフォーマンスを向上させる新しい手法です。
下記の論文をまとめてみました。
2203.14465 (arxiv.org)

STaRの主要プロセス

  1. 少数ショットプロンプトで根拠生成

  2. 誤答の場合、正解を与えて再度根拠生成(合理化)

  3. 正解を導いた根拠でファインチューニング

  4. 反復

合理化とは

合理化は、モデルが正しく回答できなかった問題に対して、正解を提供し新しい根拠を生成させるプロセスです。これにより、モデルは逆方向に推論することができ、正解が与えられればより簡単に有用な根拠を生成できます。合理化によって生成された根拠は、ヒントなしで生成されたかのように訓練データに追加されます。

アルゴリズム

Algorithm 1: STaR
Input: M: 事前学習済みLLM; データセット D = {(xi, yi)}Di=1 (少数ショットプロンプト付き)

1: M0 ← M  # 元のモデルをコピー
2: for n in 1...N do  # 外部ループ
3:     (r̂i, ŷi) ← Mn-1(xi) ∀i ∈ [1, D]  # 根拠生成を実行
4:     (r̂rati, ŷrati) ← Mn-1(add_hint(xi, yi)) ∀i ∈ [1, D]  # 合理化を実行
5:     Dn ← {(xi, r̂i, yi) | i ∈ [1, D] ∧ ŷi = yi}  # 正解を使用して根拠をフィルタリング
6:     Dratn ← {(xi, r̂rati, yi) | i ∈ [1, D] ∧ ŷi ≠ yi ∧ ŷrati = yi}  # 合理化された根拠をフィルタリング
7:     Mn ← train(M, Dn ∪ Dratn)  # 正解の解決策で元のモデルをファインチューニング - 内部ループ
8: end for

4行目と6行目が合理化に対応しています。これらを除くと、合理化なしのSTaRになります。

Pythonコード

from sympy import symbols, Function, Sum, Eq, diff, exp
from typing import List, Tuple

# シンボルの定義
M, x, y, r_hat, y_hat = symbols('M x y r_hat y_hat')
p_M = Function('p_M')
E = Function('E')
indicator = Function('indicator')

def star_algorithm(M: Function, D: List[Tuple[symbols, symbols]], N: int):
    """
    STaRアルゴリズムの実装
    
    :param M: 言語モデル(SymPy関数として表現)
    :param D: データセット((x, y)のリスト)
    :param N: 反復回数
    """
    M_0 = M

    for n in range(1, N+1):
        D_n = []
        D_rat_n = []

        for x_i, y_i in D:
            # 根拠生成(実際のモデルでは、これは複雑な処理になります)
            r_hat_i, y_hat_i = M(x_i).subs(M, M_0)

            # 合理化(実際のモデルでは、これは複雑な処理になります)
            r_hat_rat_i, y_hat_rat_i = M(x_i, y_i).subs(M, M_0)

            # フィルタリング
            if y_hat_i == y_i:
                D_n.append((x_i, r_hat_i, y_i))
            elif y_hat_rat_i == y_i:
                D_rat_n.append((x_i, r_hat_rat_i, y_i))

        # 目的関数の定義
        J = Sum(E(indicator(y_hat == y) * p_M(y_hat, r_hat, x)), (x, y))

        # 勾配の計算
        grad_J = diff(J, M)

        # ここで、実際のモデルの更新を行います(簡略化のため省略)
        M_n = M + grad_J

        M_0 = M_n

    return M_0

# 使用例
x, y = symbols('x y')
M = Function('M')(x)
D = [(x, y) for _ in range(10)]  # サンプルデータセット
N = 5  # 反復回数

final_model = star_algorithm(M, D, N)
print(f"Final model: {final_model}")

数学的基礎

STaRは以下のRL型ポリシー勾配目的関数の近似として解釈できます:

式(1):

$$
J(M, X, Y) = \sum_i E_{\hat{y}_i,\hat{r}_i\sim p_M(\cdot|x_i)} 1(\hat{y}_i = y_i) 
$$

式(2):

$$
\nabla J(M, X, Y) = \sum_i E_{\hat{y}_i,\hat{r}_i\sim p_M(\cdot|x_i)} [1(\hat{y}_i = y_i) \cdot \nabla \log p_M(\hat{y}_i, \hat{r}_i | x_i)]
$$

ここで、

  • $${M}$$はモデル

  • $${X}$$は入力のセット($${x_i}$$はその要素)

  • $${Y}$$は正解のセット($${y_i}$$はその要素)

  • $${\hat{y}_i}$$は予測

  • $${\hat{r}_i}$$は生成された根拠

です。

実験結果

算術

16反復後の全体精度: 89.5%

CommonsenseQA

  • STaR (合理化あり): 72.5% (訓練データ86.7%使用)

  • GPT-J直接ファインチューニング: 60.0% (訓練データ100%使用)

GSM8K

STaRは従来手法を大幅に上回る性能を示しました。

合理化の役割

合理化により、モデルは以下の分布にアクセスできます:

$$
p(r | x, y)
$$

これは根拠$${r}$$の探索空間を改善し、式(1)の目的関数のオフポリシー推定として機能します。

温度の影響

高温サンプリングは以下の理由で逆効果です:

  1. 不正確な推論による正解の確率増加

  2. 悪質な推論での訓練による一般化阻害

結論

STaRは記号的・自然言語推論で大幅な性能向上を示し、多領域での応用可能性があります。ただし、以下の制限があります:

  1. 初期モデルの十分な能力要件

  2. 高チャンス性能タスクでの課題

これらの制限を克服することが今後の研究課題となります。

いいなと思ったら応援しよう!