STaR: 自己教示型推論器 - 推論による推論のブートストラップ
要旨
STaRは段階的な「思考の連鎖」根拠生成を通じて、複雑な推論タスクのパフォーマンスを向上させる新しい手法です。
下記の論文をまとめてみました。
2203.14465 (arxiv.org)
STaRの主要プロセス
少数ショットプロンプトで根拠生成
誤答の場合、正解を与えて再度根拠生成(合理化)
正解を導いた根拠でファインチューニング
反復
合理化とは
合理化は、モデルが正しく回答できなかった問題に対して、正解を提供し新しい根拠を生成させるプロセスです。これにより、モデルは逆方向に推論することができ、正解が与えられればより簡単に有用な根拠を生成できます。合理化によって生成された根拠は、ヒントなしで生成されたかのように訓練データに追加されます。
アルゴリズム
Algorithm 1: STaR
Input: M: 事前学習済みLLM; データセット D = {(xi, yi)}Di=1 (少数ショットプロンプト付き)
1: M0 ← M # 元のモデルをコピー
2: for n in 1...N do # 外部ループ
3: (r̂i, ŷi) ← Mn-1(xi) ∀i ∈ [1, D] # 根拠生成を実行
4: (r̂rati, ŷrati) ← Mn-1(add_hint(xi, yi)) ∀i ∈ [1, D] # 合理化を実行
5: Dn ← {(xi, r̂i, yi) | i ∈ [1, D] ∧ ŷi = yi} # 正解を使用して根拠をフィルタリング
6: Dratn ← {(xi, r̂rati, yi) | i ∈ [1, D] ∧ ŷi ≠ yi ∧ ŷrati = yi} # 合理化された根拠をフィルタリング
7: Mn ← train(M, Dn ∪ Dratn) # 正解の解決策で元のモデルをファインチューニング - 内部ループ
8: end for
4行目と6行目が合理化に対応しています。これらを除くと、合理化なしのSTaRになります。
Pythonコード
from sympy import symbols, Function, Sum, Eq, diff, exp
from typing import List, Tuple
# シンボルの定義
M, x, y, r_hat, y_hat = symbols('M x y r_hat y_hat')
p_M = Function('p_M')
E = Function('E')
indicator = Function('indicator')
def star_algorithm(M: Function, D: List[Tuple[symbols, symbols]], N: int):
"""
STaRアルゴリズムの実装
:param M: 言語モデル(SymPy関数として表現)
:param D: データセット((x, y)のリスト)
:param N: 反復回数
"""
M_0 = M
for n in range(1, N+1):
D_n = []
D_rat_n = []
for x_i, y_i in D:
# 根拠生成(実際のモデルでは、これは複雑な処理になります)
r_hat_i, y_hat_i = M(x_i).subs(M, M_0)
# 合理化(実際のモデルでは、これは複雑な処理になります)
r_hat_rat_i, y_hat_rat_i = M(x_i, y_i).subs(M, M_0)
# フィルタリング
if y_hat_i == y_i:
D_n.append((x_i, r_hat_i, y_i))
elif y_hat_rat_i == y_i:
D_rat_n.append((x_i, r_hat_rat_i, y_i))
# 目的関数の定義
J = Sum(E(indicator(y_hat == y) * p_M(y_hat, r_hat, x)), (x, y))
# 勾配の計算
grad_J = diff(J, M)
# ここで、実際のモデルの更新を行います(簡略化のため省略)
M_n = M + grad_J
M_0 = M_n
return M_0
# 使用例
x, y = symbols('x y')
M = Function('M')(x)
D = [(x, y) for _ in range(10)] # サンプルデータセット
N = 5 # 反復回数
final_model = star_algorithm(M, D, N)
print(f"Final model: {final_model}")
数学的基礎
STaRは以下のRL型ポリシー勾配目的関数の近似として解釈できます:
式(1):
$$
J(M, X, Y) = \sum_i E_{\hat{y}_i,\hat{r}_i\sim p_M(\cdot|x_i)} 1(\hat{y}_i = y_i)
$$
式(2):
$$
\nabla J(M, X, Y) = \sum_i E_{\hat{y}_i,\hat{r}_i\sim p_M(\cdot|x_i)} [1(\hat{y}_i = y_i) \cdot \nabla \log p_M(\hat{y}_i, \hat{r}_i | x_i)]
$$
ここで、
$${M}$$はモデル
$${X}$$は入力のセット($${x_i}$$はその要素)
$${Y}$$は正解のセット($${y_i}$$はその要素)
$${\hat{y}_i}$$は予測
$${\hat{r}_i}$$は生成された根拠
です。
実験結果
算術
16反復後の全体精度: 89.5%
CommonsenseQA
STaR (合理化あり): 72.5% (訓練データ86.7%使用)
GPT-J直接ファインチューニング: 60.0% (訓練データ100%使用)
GSM8K
STaRは従来手法を大幅に上回る性能を示しました。
合理化の役割
合理化により、モデルは以下の分布にアクセスできます:
$$
p(r | x, y)
$$
これは根拠$${r}$$の探索空間を改善し、式(1)の目的関数のオフポリシー推定として機能します。
温度の影響
高温サンプリングは以下の理由で逆効果です:
不正確な推論による正解の確率増加
悪質な推論での訓練による一般化阻害
結論
STaRは記号的・自然言語推論で大幅な性能向上を示し、多領域での応用可能性があります。ただし、以下の制限があります:
初期モデルの十分な能力要件
高チャンス性能タスクでの課題
これらの制限を克服することが今後の研究課題となります。