[Guidance#5]任意の正規表現を強制できるpattern guidesを試してみる
Guidanceの一つの強みである生成結果の強制。これまではプロンプトを工夫して、いかに期待する結果の生成確率を高めるかということが行われてきましたが、それとはまた違ったアプローチのようです。
pattern guidesついて公式notebookがあったので、そちらを少し参考に試してみました。またGuidanceでは実際にどのような実装がなされているのかを知るべく、内部処理について調査してみました。
pattern guidesを試す
今回は2000年から2020年までのいずれかの日付をYYYYMMDD形式で生成してもらうタスクを想定して試してみたいと思います。
検証にはgpt-2を用い、pattern guidesありとなしで実際の生成物をみていきます。
以下のColabで試せます
プログラムは次のとおりになります。
!pip install guidance transformers
!pip show guidance
> Version: 0.0.51
import guidance
llm = guidance.llms.Transformers("gpt2")
guidance.llms.Transformers.cache.clear()
# パターンガイドなしの無効な出力
program = guidance("""Generate a date from 2000 to 2020 in 'YYYY-MM-DD' format.
{{gen 'date' max_tokens=10}}""", llm=llm)
# プログラムの実行
output = program()
print("\nパターンガイドなしで生成されたdate\n", output["date"])
# パターンガイドありの有効な出力
program = guidance("""Generate a date from 2000 to 2020 in 'YYYY-MM-DD' format.
{{gen 'date' max_tokens=10 pattern="20(00|01|02|03|04|05|06|07|08|09|10|11|12|13|14|15|16|17|18|19|20)-((0[1-9])|(1[0-2]))-((0[1-9])|([12][0-9])|(3[01]))"}}""", llm=llm)
# プログラムの実行
output = program()
print("\nパターンガイドありで生成されたdate\n", output["date"])
pattern指定としては、ゴリ押し感はありますが以下のような正規表現で2000年1月1日〜2020年12月31日までの適切な日付を取得するようにしています。
20(00|01|02|03|04|05|06|07|08|09|10|11|12|13|14|15|16|17|18|19|20)-((0[1-9])|(1[0-2]))-((0[1-9])|([12][0-9])|(3[01]))
結果の比較
ちゃんと指定した正規表現を強制できていますね!
どのように内部実装されているのか
guidance.llms._transformers.RegexLogitsProcessorがその正体のようです。
簡潔にいうと、RegexLogitsProcessorは、モデルが生成する候補トークンを正規表現パターンに従ってフィルタリングを行い、パターンマッチするトークンのロジットスコア(モデルの出力層のスコア)を大幅に上げることで、その後のsoftmax関数で意図するトークンが選択(生成)されやすくしています。
具体例を交えて処理の流れを理解する
システムがプロンプト "The weather is " を受け取ったとします。
次のトークンを選ぶために、モデルは各可能なトークン("hot", "cold", "nice", ...)のロジットスコアを計算します。たとえば、それぞれのロジットスコアが {"hot": 2.4, "cold": 0.7, "nice": 1.3, ...} のようになったとします。
ここで RegexLogitsProcessor が効果を発揮します。仮に私たちが "The weather is (hot|cold)" というパターンに一致する出力を目指しているとしましょう。
このプロセッサは、次に来るべきトークンの選択が "hot" と "cold" になるようにロジットスコアを調整します。具体的には、これらのトークンのロジットスコアを大幅に増加させ、他のトークン(この場合は"nice"など)のロジットスコアは相対的に小さくなります。つまり、ロジットスコアは {"hot": 2.4 + 10, "cold": 0.7 + 10, "nice": 1.3, ...} のようになります。
この修正されたロジットスコアに基づいて、次のトークンが選ばれます。
貪欲な選択(temperature==0)をしている場合、最もロジットスコアが高い "hot" が選ばれます。
一方、確率的にサンプリングをしている場合、"hot" と "cold" のいずれかが選ばれる確率が大幅に上昇し、"nice" などの他のトークンが選ばれる確率はほぼ0になります。しかし、"hot" と "cold" のどちらが選ばれるかは確率的に決定されます。
これにより、生成プロセスは特定のパターンに一致するように誘導され、結果的に "The weather is hot" または "The weather is cold" のような出力が得られます。
コードベースで理解する
主な処理部分を見てみましょう。
# _transformers.py L508
# compute the bias values
self.bias_vector[:] = 0
sort_inds = torch.argsort(scores, 1, True)
to_bias = []
for i in range(min(sort_inds.shape[1], self.max_consider)):
proposed_string = (self.current_strings[0] + self.decode([sort_inds[0,i]]))[self.forced_chars:]
m = self.pattern.fullmatch(proposed_string, partial=True)
if m:
to_bias.append(int(sort_inds[0, i]))
if self.is_greedy:
break
ここでは、最初に候補のトークンをロジットスコア(scores)に基づいてソートします。次に、上位の候補トークンを順に見ていき、それらが正規表現パターンに一致するかどうかを確認します。一致する場合、そのトークンのインデックスがto_biasリストに追加されます。
is_greedy(temperature==0)であれば、パターンマッチした一番ロジットスコアの高いトークンを見つけ次第、breakをして決定論的な選択を行っています。
そして、後続の処理を見てみると、
# _transformers.py L524
# bias allowed tokens
min_to_bias = float(scores[0, to_bias].min())
bias_value = scores[0, sort_inds[0, 0]] - min_to_bias + 10
for x in to_bias:
self.bias_vector[x] = bias_value
return scores + self.bias_vector.to(scores.device)
ここで、to_biasリストに含まれるトークンのロジットスコアにバイアスが追加されます。つまり、これらのトークンの選択確率が上がります。バイアスは、最も高いロジットスコアを持つトークンのスコアから最小のバイアス対象トークンスコアを引いた値に10を加えたものです。
このプロセスにより、バイアスが追加されたトークン(パターンに一致するトークン)は、他のトークン(パターンに一致しないトークン)よりも選択されやすくなります。
ただし、一致しないトークンは完全に排除されるわけではなく、一部のトークンはまだ選択の可能性が残されています。ただし、その確率は一致するトークンよりもはるかに低いです。
利用できるケース
LLMの出力層のスコア部分に対してのアプローチであるため、基本的にはLLM APIに対しては利用することは難しそうです。
コードを見る限り、guidance.llms.Transformersで扱えるモデルのみのようでした(huggingfaceにあるモデルや、vicunaなどのOSS LLM)。
所感
期待する出力結果を得る方法として、これまではプロンプトエンジニアリングによって依存していた部分に、強力な新しい選択肢が加わりました。個人的に、パターンマッチはLLMを機能として活用する上で、かなり有効な手になってくると思います。
まだ明確なユースケースについては出会っていませんが、今後商用利用可能な有力なOSS LLMが出てきた時に、OpenAIなどのLLMモデルと、どちらを利用するかの判断軸として効いてきそうな気がしました。