大規模言語モデルの比較メモ
WIP
比較するモデル
Chinchilla [2203.15556] Training Compute-Optimal Large Language Models (arxiv.org)
PaLM [2204.02311] PaLM: Scaling Language Modeling with Pathways (arxiv.org)
LLaMA [2302.13971] LLaMA: Open and Efficient Foundation Language Models (arxiv.org)
Training Dataset
Chinchilla
PaLM
LLaMA
Model Architecture, training setup
Chinchilla
Optimizer
AdanW
cosine schedule (cosine cycle length = 1.0 * num steps)
PaLM
780B tokens
標準的Transformer[Vaswani+, 2017]との違い
SwiGLU[Shazeer, 2020] activation function
Swish(xW)*xV
Parallel Layers [Wang and Komatsuzaki, 2021]
通常:y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))
並列:y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))
8Bでは少し性能低下、62Bでは性能低下無し
Multi-Query Attention
?
RoPE Embeddings[Su+, 2021]
Shared Input-Output Embeddings
過去の研究でも頻繁に採用されているテクニック
No Biases
training時に安定性が向上
Vocabulary
256k tokenのSentencePiece
training setup
weight initialization
embeddingsとlayer norm scale以外はfan-in variance scalingで初期化
W ∼ N (0, 1/ √ nin)
nin = input dim of the kernel
embeddings
E~N(0,1)
入出力のembeddings layerが共有されているため、softmax前のlogitsを 1/√nでスケーリング
n = embdding size
Optimizer
Adafactor (without factorization)
Optimization hyperparameters
learning rate = 10^-2
最初の10,000stepsではlr=10^-2、その後1/√kで減衰させる
k=training steps
β1 = 0.9
β2=1.0-k^-0.8
大型モデルでは標準的なβ2=0.99よりも安定性が高い
global norm gradient clipping = 1.0
weight decay = lr^2.0
Loss function
標準的な言語モデリング損失(ラベル平滑化無し、全トークンの平均対数確率)
z_loss = 10^-4 * log2(Z) を補助損失として使用
softmax normarizer log(Z)が0に値被くように促す
これによってtrainingの安定性が向上
Sequence length
2048
入力例は連結され、2048トークンのsequenceに分割
そのためpaddingトークンは無いが、入力例は途中で分割される可能性あり
入力例は特別な[eod]トークンで区切られる
Batch Size
training中にbatch sizeを増加させる
最大モデルでは50k stepsまでbs=512 (1M tokens) 115k stepsまで bs=1024 (1M tokens)、最後にbs = 2048 (4M tokens)で、255k stepsでtrainingが完了するまで続ける
なぜ動的なbatch size scheduleをするのか?
小さいbsはtraining初期においてはサンプル効率が高い、一方大きいbsはより良い勾配推定によりtraining後半に有益 [Smith+, 2018] [McCandlish+, 2018]
大きいbsによりmatrix multiplication dimensionsが大きくなり、TPU効率が高まる
Dropout
Dropout無し
finetuneの時はdropout = 0.1
LLaMA
標準的Transformer[Vaswani+, 2017]との違い
Pre-normalization (GPT-3)
RMSNorm normalizing function[Ahang and Sennrich, 2019]を使用
SwiGLU[Shazeer, 2020] activation function (PaLM)
次元数としてPaLMの4dの代わりに2/3 * 4dを使用
Rotary Embeddings[Su+, 2021] (GPTNeo)
絶対的な位置embddingsを削除し、各レイヤーにrotary positional embeddings(RoPE)を追加
Optimizer
AdamW (β1=0.9, β2=0.95)
cosine schedule (最終的な値はmax/10)
weight decay = 0.1
gradient clipping = 1.0
warmup steps = 1000