Perplexityをもとに、複数の大規模言語モデルを切り替えて推論するシステムの簡単なコード実装



はじめに

最近は複数の大規模言語モデルを組み合わせて使用するシステム(Mixture of experts: MoE, Branch-Train-Merge: BTM)にハマっています。
特にBTMは、独立にモデルを訓練することができるので、「文学専用のAI」、「科学専用のAI」、…のように、専門家の集まりを、わかりやすく構築・統合可能です。
もちろんBTMにも、モデルの構築法、最後の統合方法などの諸課題があります。


今回は、モデルを統合するための簡単な実装コードを書いてみます。
最近は、普通にmergekitもあるようですが、勉強も兼ねた実装です。


Perplexityとは


データセットのクリーニングにも使えます。


アプローチ

与えられた入力文章に対するPerplexity(困惑さ)を指標に、使用するモデルを切り替えるシステムを作ります。

イメージ的には、

「文学を学んだモデルを作る」、「科学を学んだモデルを作る」
→「文学系の入力テキストを与える」
→「文学モデルは、文章に馴染み深い(Perplexityが小さい)」、「科学モデルは、文章に馴染みが薄い(Perplexityが大きい)」
→「文学モデルを用いる」

という流れでモデル選択が進みます。

BTMの初期モデルも、このアルゴリズムが使われている印象です。

実装例

モデルの事前訓練をする余裕がないので、今回は試しに、英語が得意なLLama2-7bと、日本語でファインチューニングしたElyza-7bを統合(merge)したシステムを作ってみようと思います。

英語の質問にはllama、日本語の質問にはelyzaで答えることができればコンセプト実証に成功です。

コード

関数とモデルの定義

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
import numpy as np

def perplexity(model, tokenizer, text) -> torch.Tensor:
    tokenized_input = tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)
    with torch.inference_mode():
        output = model(tokenized_input, labels=tokenized_input)
    ppl = torch.exp(output.loss)
    return ppl.item()

class MoE:
    def __init__(self):
        self.models=[]
        self.coef=[]

    def set_coefs(self,coef):
        self.coef=coef

    def append_ELM(self,model,tokenizer):
        pipe=pipeline("text-generation",model=model,tokenizer=tokenizer,
                      max_new_tokens=100
                      )
        self.models.append((model,tokenizer,pipe))
        self.coef.append(1)

    def calc_perplexity(self,text):
        ppl_list=[]
        for model,tokenizer,_ in self.models:
            ppl_list.append(perplexity(model,tokenizer,text))

        return ppl_list

    def ask(self,text,verbose=True):
        ppl_array=np.array(self.calc_perplexity(text))
        ppl_array=ppl_array*np.array(self.coef)
        best_model_id=np.where(ppl_array==min(ppl_array))[0][0]
        if verbose:
            print("perplexity list")
            for i,ppl in enumerate(ppl_array):
                print(i,ppl)
            print(f"model id {best_model_id} is used")
        pipe=self.models[best_model_id][2]
        return pipe(text)[0]['generated_text']

モデルの登録

moe=MoE()


model_name_list =[ 
    "meta-llama/Llama-2-7b-chat-hf",
    "elyza/ELYZA-japanese-Llama-2-7b-instruct",
                  ]

for model_name in model_name_list:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto",
        torch_dtype=torch.float16
    )

    moe.append_ELM(model,tokenizer)

推論

moe.set_coefs([1,0])

text_list=[
    "hello, I'm John",
    "こんにちは",
    "how can we avoid global warming?",
    "地球温暖化を防ぐにはどうしたらいいか?",
      ]

for text in text_list:
    print("-----")
    print(text)
    response=moe.ask(text)
    print(response)

ポイント

モデルごとにperplexityの大小がかなり違っていたので、今回は便宜的に、coefで補正をかけました。

結果

model 0がllama2 (英語に強いモデル)
model 1がelyza (日本語に強いモデル)
です。

-----
Q. hello, I'm John
perplexity list
0 28.584484100341797
1 32.322213649749756
model id 0 is used
A. "Nice to meet you, John," I replied, trying to keep my voice steady.

"So, what brings you to this neck of the woods?" he asked, his eyes scanning my face.

I hesitated, unsure of how much I wanted to reveal. "Just visiting," I said vaguely.

"Ah, well, you're in luck," he said, grinning. "We've got

-----
Q. こんにちは
perplexity list
0 16.661598205566406
1 10.08517861366272
model id 1 is used
A. 今回は、「僕のヒーローアカデミア」の第22話のネタバレと感想を書いていきます。
第22話のあらすじ 桜並木の下で、爆笑しながら歩いている皆に呼び止められる。

-----
Q. how can we avoid global warming?
perplexity list
0 56.53865051269531
1 78.20221424102783
model id 0 is used
A. The world's top scientists have made it clear that global warming is one of the most pressing issues of our time, and that urgent action is needed to avoid the worst impacts of climate change. Here are some of the ways we can reduce our carbon footprint and help prevent global warming:

-----
Q. 地球温暖化を防ぐにはどうしたらいいか?
perplexity list
0 5.261678218841553
1 8.543850183486938
model id 0 is used
A. 地球温暖化を防ぐにはどうしたらいいか? 地球温暖化防止のためには、以下のような方策があります。
1. 石油、天然ガス、コールなどの非対称的なエネルギー源の使用を減少させます。
2. ソーラー、ウインド、水力な

最後の質疑以外は、狙ったエキスパートモデルが応答してくれました。


まとめ

perplexityを指標にすることで、与えられた質問ごとに、複数の言語モデルをいい感じにスイッチングできそうなことがわかりました。
スイッチングするアルゴリズム(ルーター)をニューラルネットにしたり、モデル切り替えをtokenレベルで行ってみるなど、諸々の改善はできそうです。

あるいは、既存のキットを使うのも良さそうです。

3/20追記: Transformersライブラリのpipelineへの埋め込み

transformersライブラリのpipelineから呼び出したかったので、一連の処理を埋め込みました。
かなりのハリボテ作業ですので、ご注意ください(GPT2クラスに偽装しています)。

class

from transformers import GPT2Config, GPT2Model
import torch
import numpy as np


#GPT2クラスを継承します。中身は空です。
class MoEWrapper(GPT2Model):

    config_class = GPT2Config
    #load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"

    verbose=True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.model_list=[]


    def append_model(self,model):
        self.model_list.append(model)
    
    def set_tokenizer(self,tokenizer):
        self.tokenizer=tokenizer

    def set_model_id(self,model_id):
        self.model=self.model_list[model_id]

    def calc_perplexity(self,tokenized_input):
        ppl_list=[]
        for model in self.model_list:
            ppl_list.append(perplexity(model,tokenized_input))
        return np.array(ppl_list)


    # wrapper functions
  # generateのノリで、スイッチングする機能を実装すれば、forwardもできるはずです
    #def forward(self,*args, **kwargs):
    #    ret=self.model.forward(*args,**kwargs)
    #    return ret

    def generate(self,input_ids, attention_mask,
                  **generate_kwargs):

        ppl_array=self.calc_perplexity(input_ids)
        best_model_id=np.where(ppl_array==min(ppl_array))[0][0]
        self.set_model_id(best_model_id)
 
        if self.verbose:
            print(f"model {best_model_id} will be used")
            print("ppl array: ",ppl_array)


        ret=self.model.generate(input_ids=input_ids, 
                                attention_mask=attention_mask,
                                  **generate_kwargs)
        return ret


def perplexity(model, tokenized_input) -> torch.Tensor:
    with torch.inference_mode():
        output = model(tokenized_input, labels=tokenized_input)
    ppl = torch.exp(output.loss)
    return ppl.item()

呼び出し

from transformers import GPT2Config, GPT2Model,PreTrainedModel,PretrainedConfig
from MoEWrapper import MoEWrapper
cfg=PretrainedConfig()
cfg=GPT2Config()

moe=MoEWrapper(cfg)

model_name_list =[ 
    "モデル1",
    "モデル2",
                  ]

tokenizer = AutoTokenizer.from_pretrained(model_name_list[0])
for model_name in model_name_list:
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto",
        torch_dtype=torch.float16
    )
    moe.append_model(model)

#pipelineで使う
max_new_tokens=100
pipe=pipeline("text-generation",model=moe,
              tokenizer=tokenizer,
              max_new_tokens=max_new_tokens,
                       )

3/26追記 HuggingFaceへの登録


いいなと思ったら応援しよう!