見出し画像

MeZOによるLlama2-7Bモデルの日本語ファインチューニング

MeZOを利用してA100 80GBx1デバイスで7Bモデルの日本語ファインチューニングができるか試しました。

Llama2-7Bは高性能と言われていますが日本語性能はイマイチと言われてます。そのLLama2に見事日本語を覚えさせることができるか。1エポックで24時間かかりました。

0%| | 0/100 [00:00<?, ?it/s]2023-11-14 11:06:41,509 - INFO - ========= Example =========
2023-11-14 11:06:41,509 - INFO - Candidate: None
2023-11-14 11:06:41,509 - INFO - Correct candidate: この道は四国八十八箇所巡礼の遍路道のひとつとして使われています。具体的には金剛福寺から延光寺へ向かう遍路道です。
2023-11-14 11:06:42,357 - INFO - === Prompt ===
2023-11-14 11:06:42,358 - INFO - <s>Sample(id=26, data='human: この高知県道21号土佐清水宿毛線は、どこからどこまでの道なんですか?\ngpt: この道は高知県土佐清水市から宿毛市に至る道です。\nhuman: この道は何の道として使われているんですか?\ngpt: ', correct_candidate='この道は四国八十八箇所巡礼の遍路道のひとつとして使われています。具体的には金剛福寺から延光寺へ向かう遍路道です。', candidates=None)
2023-11-14 11:06:42,358 - INFO - Output: この道は四国八十八箇所巡礼の遍路道のひとつとして使われています。具体的には金��
1%| | 1/100 [00:00<01:24, 1.18it/s]2023-11-14 11:06:42,358 - INFO - ========= Example =========
2023-11-14 11:06:42,358 - INFO - Candidate: None
2023-11-14 11:06:42,358 - INFO - Correct candidate: 波号第二百七潜水艦は向後崎西方沖でアメリカ海軍によって海没処分されました。ここで廃棄されることになりました。
2023-11-14 11:06:43,252 - INFO - === Prompt ===
2023-11-14 11:06:43,255 - INFO - <s>Sample(id=86, data='human: 先生、波号第二百七潜水艦はどんな潜水艦だったんですか?\ngpt: 波号第二百七潜水艦は、日本海軍の潜水艦でした。波二百一型潜水艦の7番艦として建造され、太平洋戦争末期に竣工しました。しかし、実際には外海に出撃することなく、戦後に海没処分されてしまいました。\nhuman: 波号第二百七潜水艦の艦歴について教えてください。\ngpt: はい、波号第二百七潜水艦はマル戦計画の潜水艦小、第4911号艦型の7番艦として計画されました。起工は1945年4月23日で、同年5月1日に命名された後、8月14日に竣工しました。終戦時の所在地は佐世保で、11月30日に海軍省の廃止に伴い除籍されました。そして、1946年4月5日にアメリカ海軍により海没処分されました。\nhuman: 波号第二百七潜水艦は最終的にどこで処分されたんですか?\ngpt: ', correct_candidate='波号第二百七潜水艦は向後崎西方沖でアメリカ海軍によって海没処分されました。ここで廃棄されることになりました。', candidates=None)
2023-11-14 11:06:43,255 - INFO - Output: 波号第二百七潜水艦は向後崎西方沖でアメリカ海軍によって海没処分されました。ここで��
2%|▏ | 2/100 [00:01<01:25, 1.14it/s]2023-11-14 11:06:43,255 - INFO - ========= Example =========
2023-11-14 11:06:43,255 - INFO - Candidate: None
2023-11-14 11:06:43,255 - INFO - Correct candidate: エレーナ・クジミナは1979年10月15日に亡くなりました。
2023-11-14 11:06:43,800 - INFO - === Prompt ===
2023-11-14 11:06:43,801 - INFO - <s>Sample(id=2, data='human: エレーナ・クジミナはどのような職業でしたか?\ngpt: エレーナ・クジミナはロシアの女優でした。\nhuman: エレーナ・クジミナの生年月日はいつですか?\ngpt: エレーナ・クジミナの生年月日は1909年2月17日です。\nhuman: エレーナ・クジミナはいつ亡くなりましたか?\ngpt: ', correct_candidate='エレーナ・クジミナは1979年10月15日に亡くなりました。', candidates=None)
2023-11-14 11:06:43,801 - INFO - Output: エレーナ・クジミナは1979年10月15日に亡くなりました。
2023-11-14 11:07:55,189 - INFO - Candidate: None
2023-11-14 11:07:55,189 - INFO - Correct candidate: 彼は身長208cmのストレッチ・フォー型の選手で、アウトサイドからのシュートが得意です。デフェンス面ではムラがあると書かれています。
2023-11-14 11:07:56,042 - INFO - === Prompt ===
2023-11-14 11:07:56,043 - INFO - <s>Sample(id=12783, data='human: 先生、ライアン・アンダーソンはどこの大学でプレイしていたんですか?\ngpt: ライアン・アンダーソンはカリフォルニア大学でプレイしていました。\nhuman: 先生、彼はどのチームでNBAデビューしましたか?\ngpt: 彼はニュージャージー・ネッツでNBAデビューしました。\nhuman: 先生、彼の個人成績について教えてください。\ngpt: 彼の個人成績については文中に詳しく書かれていないのでわかりません。\nhuman: 先生、彼はどんなプレースタイルの選手なんですか?\ngpt: ', correct_candidate='彼は身長208cmのストレッチ・フォー型の選手で、アウトサイドからのシュートが得意です。デフェンス面ではムラがあると書かれています。', candidates=None)
2023-11-14 11:07:56,043 - INFO - Output: 彼は身長208cmのストレッチ・フォー型の選手で、アウトサイドからのシュートが得意です。デフェンス

正解は学習できているようです。

コマンドライン

 $ MODEL=meta-llama/Llama-2-7b-chat-hf TASK=Chat MODE=prefix EPS=1e-1 LR=1e-3 CUDA_VISIBLE_DEVICES=0 nohup bash mezo.sh& 

今回は自分で作った日本語会話モデル(GPT-3.5-Turboを使用)を学習させるために、tasks.pyとtemplates.pyを変更しました。

tasks.pyに以下のコードスニペットを追加

class ChatDataset(Dataset):
    metric_name = "f1"
    generation = True

    def __init__(self, subtask=None, **kwargs) -> None:
        self.load_dataset()
        
    def load_dataset(self):
        dataset = load_dataset("shi3z/Japanese_Wikipedia_Conversation")
        train_examples = dataset["train"]["conversations"][100:]
        valid_examples = dataset["train"]["conversations"][:100]
        #for tr in dataset['train']:
        #    print(tr)

        train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
        valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
        self.samples = {"train": train_samples, "valid": valid_samples}
    
    # for generative tasks, candidates are []
    def build_sample(self, example, idx):
        s=""
        for line in example[:-1]:
            s+=line["from"]+": "+line["value"]+"\n"
        lastline= example[-1]
        s+=lastline["from"]+": "
        
        return Sample(
            id=idx,
            data=s,
            candidates=None,
            correct_candidate=lastline["value"]
        )
        
    def get_template(self, template_version=0):
        return {0: ChatTemplate}[template_version]()

templates.pyに以下のコードスニペットを追加

class ChatTemplate(Template):

    def encode(self, sample):

        return f"{sample}"

    def verbalize(self, sample, candidate):

        return f"{sample}{candidate}\n"
    
    def encode_sfc(self, sample):
        raise NotImplementedError

    def verbalize_sfc(self, sample, candidate):
        raise NotImplementedError

今回は面倒だったので会話のやり取りから最後の発言だけを予測するように学習しましたが、tasks.pyのbuild_sampleを会話のランダムな部分で区切ることでもっと汎化性能を上げられそうです。