大規模言語モデルの比較メモ

WIP

比較するモデル

Training Dataset

Chinchilla

PaLM

LLaMA

Model Architecture, training setup

Chinchilla

Optimizer
AdanW
cosine schedule (cosine cycle length = 1.0 * num steps)

PaLM

780B tokens

標準的Transformer[Vaswani+, 2017]との違い

  • SwiGLU[Shazeer, 2020] activation function

    • Swish(xW)*xV

  • Parallel Layers [Wang and Komatsuzaki, 2021]

    • 通常:y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))

    • 並列:y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))

    • 8Bでは少し性能低下、62Bでは性能低下無し

  • Multi-Query Attention

    • ?

  • RoPE Embeddings[Su+, 2021]

  • Shared Input-Output Embeddings

    • 過去の研究でも頻繁に採用されているテクニック

  • No Biases

    • training時に安定性が向上

  • Vocabulary

    • 256k tokenのSentencePiece

training setup

  • weight initialization

    • embeddingsとlayer norm scale以外はfan-in variance scalingで初期化

      • W ∼ N (0, 1/ √ nin)

        • nin = input dim of the kernel

    • embeddings

      • E~N(0,1)

    • 入出力のembeddings layerが共有されているため、softmax前のlogitsを 1/√nでスケーリング

      • n = embdding size

  • Optimizer

    • Adafactor (without factorization)

  • Optimization hyperparameters

    • learning rate = 10^-2

    • 最初の10,000stepsではlr=10^-2、その後1/√kで減衰させる

      • k=training steps

    • β1 = 0.9

    • β2=1.0-k^-0.8

      • 大型モデルでは標準的なβ2=0.99よりも安定性が高い

    • global norm gradient clipping = 1.0

    • weight decay = lr^2.0

  • Loss function

    • 標準的な言語モデリング損失(ラベル平滑化無し、全トークンの平均対数確率)

    • z_loss = 10^-4 * log2(Z) を補助損失として使用

      • softmax normarizer log(Z)が0に値被くように促す

      • これによってtrainingの安定性が向上

  • Sequence length

    • 2048

    • 入力例は連結され、2048トークンのsequenceに分割

      • そのためpaddingトークンは無いが、入力例は途中で分割される可能性あり

    • 入力例は特別な[eod]トークンで区切られる

  • Batch Size

    • training中にbatch sizeを増加させる

    • 最大モデルでは50k stepsまでbs=512 (1M tokens) 115k stepsまで bs=1024 (1M tokens)、最後にbs = 2048 (4M tokens)で、255k stepsでtrainingが完了するまで続ける

    • なぜ動的なbatch size scheduleをするのか?

      • 小さいbsはtraining初期においてはサンプル効率が高い、一方大きいbsはより良い勾配推定によりtraining後半に有益 [Smith+, 2018] [McCandlish+, 2018]

      • 大きいbsによりmatrix multiplication dimensionsが大きくなり、TPU効率が高まる

  • Dropout

    • Dropout無し

    • finetuneの時はdropout = 0.1

LLaMA

標準的Transformer[Vaswani+, 2017]との違い

  • Pre-normalization (GPT-3)

    • RMSNorm normalizing function[Ahang and Sennrich, 2019]を使用

  • SwiGLU[Shazeer, 2020] activation function (PaLM)

    • 次元数としてPaLMの4dの代わりに2/3 * 4dを使用

  • Rotary Embeddings[Su+, 2021] (GPTNeo)

    • 絶対的な位置embddingsを削除し、各レイヤーにrotary positional embeddings(RoPE)を追加

Optimizer
AdamW (β1=0.9, β2=0.95)
cosine schedule (最終的な値はmax/10)
weight decay = 0.1
gradient clipping = 1.0
warmup steps = 1000

Performance

Chinchilla


PaLM

LLaMA

いいなと思ったら応援しよう!