MeZOを利用してA100 80GBx1デバイスで7Bモデルの日本語ファインチューニングができるか試しました。
Llama2-7Bは高性能と言われていますが日本語性能はイマイチと言われてます。そのLLama2に見事日本語を覚えさせることができるか。1エポックで24時間かかりました。
正解は学習できているようです。
コマンドライン
$ MODEL=meta-llama/Llama-2-7b-chat-hf TASK=Chat MODE=prefix EPS=1e-1 LR=1e-3 CUDA_VISIBLE_DEVICES=0 nohup bash mezo.sh&
今回は自分で作った日本語会話モデル(GPT-3.5-Turboを使用)を学習させるために、tasks.pyとtemplates.pyを変更しました。
tasks.pyに以下のコードスニペットを追加
class ChatDataset(Dataset):
metric_name = "f1"
generation = True
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset()
def load_dataset(self):
dataset = load_dataset("shi3z/Japanese_Wikipedia_Conversation")
train_examples = dataset["train"]["conversations"][100:]
valid_examples = dataset["train"]["conversations"][:100]
train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example, idx):
s=""
for line in example[:-1]:
s+=line["from"]+": "+line["value"]+"\n"
lastline= example[-1]
s+=lastline["from"]+": "
return Sample(
id=idx,
data=s,
candidates=None,
correct_candidate=lastline["value"]
)
def get_template(self, template_version=0):
return {0: ChatTemplate}[template_version]()
templates.pyに以下のコードスニペットを追加
class ChatTemplate(Template):
def encode(self, sample):
return f"{sample}"
def verbalize(self, sample, candidate):
return f"{sample}{candidate}\n"
def encode_sfc(self, sample):
raise NotImplementedError
def verbalize_sfc(self, sample, candidate):
raise NotImplementedError
今回は面倒だったので会話のやり取りから最後の発言だけを予測するように学習しましたが、tasks.pyのbuild_sampleを会話のランダムな部分で区切ることでもっと汎化性能を上げられそうです。