【実装例】内積計算プログラムの高速化
超高性能プログラミング技術のメモ(13)
技術を忘れないように、書き残しています。
前回、内積計算を例に演算とメモリ転送のオーバーラップを説明しました。今回は、内積計算をちゃんと実装して、速度の比較をしてみました。
ベースコードの作成
今回は、次のようなプログラムをチューニングしていくことにします。
double ddot(size_t n, const double* x, const double* y)
{
double value = 0e0;
while( n-- ){ value += (*(x++))*(*(y++)); }
return value;
}
データ型は倍精度doubleとし、メモリアクセスパターンはシーケンシャルパターンを前提にします。doubleの内積(=dot積)なので、ddotという関数名にしています。チューニングする方の関数名はfast_ddotにしておきます。
見ての通り、内積計算はとっても簡単なプログラムです。
では、これを少しずつ変形していきましょう。
アライメント判定値を作成する
C言語で記述するうちはあまり関係ありませんが、アセンブリ言語ではアライメントが境界に揃っているかどうかで使用できる命令が変わります。そこで、ベクトルxとベクトルyのアドレスから、次のような場合分けを行います。
b00 : xとyの両方のアドレスが、境界に揃っている
b01 : yのアドレスだけが、境界に揃っている
b10 : xのアドレスだけが、境界に揃っている
b11 : xとyの両方のアドレスが、境界に揃っていない
ここで、bXXはビットを表します。このようなビットを計算するために、次のようなコードを追加します。ALIGN_CHECKは、32ビット境界であれば0x1f(b1111)に、64ビット境界であれば0x3f(b11111)にセットしたマクロ変数です。
uint64_t xalign = ((uint64_t)x) & ALIGN_CHECK;
uint64_t yalign = ((uint64_t)y) & ALIGN_CHECK;
uint64_t align = ((((yalign>0)?1:0)<<1) | ((xalign>0)?1:0));
xalignは、ポインタxのアドレスを64ビット整数uint64_tに型変換し、下位ビットをALIGN_CHECKとのAND演算によって取り出しています。yalignも同様です。xalign=0であればアライメント境界と揃っているので0を、そうでなければ1をalignにセットしています。yalignも、同様の計算をして、alignの2ビット目に値をセットしています。
レジスタ数分の一時変数を用意する
CPUが保有するレジスタ数を超えた数の一時変数を使用すると、レジスタの値をキャッシュに一時退避する処理(レジスタスピル)がコンパイラによって差し込まれます。
レジスタスピルは実行速度を急激に低下させるので、一時変数数がレジスタ数を超えないようにします。具体的には、レジスタ数を同じ数の一時変数を宣言しておき、それ以外を使わないようにします。
/*
YMM regiser variables
*/
double ymm0,ymm1,ymm2,ymm3;
double ymm4,ymm5,ymm6,ymm7;
double ymm8,ymm9,ymm10,ymm11;
double ymm12,ymm13,ymm14,ymm15;
ここでは、x86_64アーキテクチャを前提とし、16個のYMMレジスタが使えるものとします。ループインデックス変数や戻り値変数には、YMMレジスタは使われないので、これだけで十分です。
こうすることで、コンパイラにおけるレジスタ割り当てをコントロールします。特に、違う変数を使うことで、明示的に別のレジスタを使用するように指示できます。
内積計算を演算毎に分ける
アセンブリ言語は一演算につき一命令ですので、C言語でも一演算毎に計算を分割しておきましょう。こうすることで、コンパイラが狙い通りに解釈してくれるようになります。
double value = 0e0;
if( !align ){
/* aligiment case 0 */
ymm12 = 0e0;
while( n-- ){
ymm0 = *x; // Load
ymm4 = *y; // Load
ymm8 = ymm0 * ymm4; // Mul
ymm12 = ymm12 + ymm8; // Add
x++;
y++;
}
value = ymm12;
}else{
/* aligiment case 1,2 or 3 */
ymm12 = 0e0;
while( n-- ){
ymm0 = *x; // Load
ymm4 = *y; // Load
ymm8 = ymm0 * ymm4; // Mul
ymm12 = ymm12 + ymm8; // Add
x++;
y++;
}
value = ymm12;
}
return value;
分割の結果、ポインタから値を取り出すロード処理が2個、掛け算が1個、足し算が1個、ポイントのインクリメントが2個になりました。上記は、アライメント判定値が0の場合と、それ以外の場合に分けています。
レジスタ数の限界までアンロールする
上記のコードでは、用意した一時変数のうち、ymm0, ymm4, ymm8, ymm12の4つを使いました。まだ、12個の一時変数が未使用なので、これらをすべて使用するように、ループアンローリングを行いましょう。アンロールの段数は、16÷4=4段が適切と考えられます。
double value = 0e0;
if( !align ){
/* aligiment case 0 */
ymm12 = 0e0;
ymm13 = 0e0;
ymm14 = 0e0;
ymm15 = 0e0;
n_unroll = (n>>2); // Unroll by 4 elements
while( n_unroll-- ){
ymm0 = *(x ); // Load
ymm1 = *(x+1); // Load
ymm2 = *(x+2); // Load
ymm3 = *(x+3); // Load
ymm4 = *(y ); // Load
ymm5 = *(y+1); // Load
ymm6 = *(y+2); // Load
ymm7 = *(y+3); // Load
ymm8 = ymm0 * ymm4; // Mul
ymm9 = ymm1 * ymm5; // Mul
ymm10 = ymm2 * ymm6; // Mul
ymm11 = ymm3 * ymm7; // Mul
ymm12 = ymm12 + ymm8; // Add
ymm13 = ymm13 + ymm9; // Add
ymm14 = ymm14 + ymm10;// Add
ymm15 = ymm15 + ymm11;// Add
x+=4;
y+=4;
}
if( n & 2 ){
ymm0 = *(x ); // Load
ymm1 = *(x+1); // Load
ymm4 = *(y ); // Load
ymm5 = *(y+1); // Load
ymm8 = ymm0 * ymm4; // Mul
ymm9 = ymm1 * ymm5; // Mul
ymm12 = ymm12 + ymm8; // Add
ymm13 = ymm13 + ymm9; // Add
x+=2;
y+=2;
}
if( n & 1 ){
ymm0 = *x; // Load
ymm4 = *y; // Load
ymm8 = ymm0 * ymm4; // Mul
ymm12 = ymm12 + ymm8; // Add
x++;
y++;
}
value = ymm12 + ymm13 + ymm14 + ymm15;
}else{
上記コードは、アライメント判定値が0の場合だけです。それ以外の場合も同じコードになるので、省略しました。
4段でアンロールしたので、ループの反復回数も1/4にしなければなりません。そこで、反復数nを2ビット右にシフトさせて、4で割った数n_unrollを用意しています。しかし、そうするとn=1,2,3の場合が計算されません。そこで、nの2ビット目が1だったとき、1ビット目が1だったときの端数処理を加えています。
さて、こうすることで、メインループでは、レジスタに相当する全ての一時変数を使い切ることができました。ロード処理が8個、掛け算が4個、足し算が4個となりました。
ここまでで速くなったのか確認して見る
次は、アセンブラコードに変更する前に、一度パフォーマンスを測定してみましょう。
上図の、ddotは最初のシンプルなコード、fast_ddotはアンローリングまで終わったコードです。どちらも、C言語のコードになります。横軸はベクトルサイズで、1,024要素から2倍を繰り返し、536,870,912要素までを測定しています。縦軸は、浮動小数点演算数FLOPを測定時間で割った値FLOP/sです。
65536要素までは、おそらくL2キャッシュに乗ったため、だいたい平均3〜4倍高速になっています。131072要素からは、徐々にパフォーマンスが落ちていき、L3キャッシュに乗って、おおよそ2倍程度高速化してます。536,870,912要素では、メモリ容量を超えるため、劇的に遅くなっています。
というわけで、C言語だけでも平均3〜4倍は速くなりました。
このアンローリングしたC言語コードをベースコードddotとして、次からはインラインアセンブリに変更してみます。
アセンブラコードの作成
C言語の段階で、内積計算はメモリの影響が大きいことが分かりますが、もう少し速くなるかもしれないので、アセンブリ言語で実装してみます。とはいえ、全部をアセンブリ言語にすると面倒なので、インラインアセンブラを使います。
ベクトル化する
YMMレジスタはdouble型の数値を4つ持てる(=ベクトル長が4)ので、さらにアンロール段数を4倍にできます。そこで、さらにアンロールしつつ、アセンブラコードに直しましょう。
/********************************************************
ALIGNMENT CASE 0
********************************************************/
if( !align ){
//ymm12 = 0e0;
//ymm13 = 0e0;
//ymm14 = 0e0;
//ymm15 = 0e0;
__asm__ __volatile__(
"vpxor %%ymm12, %%ymm12, %%ymm12 \n\t"
"vpxor %%ymm13, %%ymm13, %%ymm13 \n\t"
"vpxor %%ymm14, %%ymm14, %%ymm14 \n\t"
"vpxor %%ymm15, %%ymm15, %%ymm15 \n\t"
::);
n_unroll = (n>>4); // Unroll by 16 elements
while( n_unroll-- ){
//ymm0 = *(x ); // Load
//ymm1 = *(x+1); // Load
//ymm2 = *(x+2); // Load
//ymm3 = *(x+3); // Load
//ymm4 = *(y ); // Load
//ymm5 = *(y+1); // Load
//ymm6 = *(y+2); // Load
//ymm7 = *(y+3); // Load
//ymm8 = ymm0 * ymm4; // Mul
//ymm9 = ymm1 * ymm5; // Mul
//ymm10 = ymm2 * ymm6; // Mul
//ymm11 = ymm3 * ymm7; // Mul
//ymm12 = ymm12 + ymm8; // Add
//ymm13 = ymm13 + ymm9; // Add
//ymm14 = ymm14 + ymm10;// Add
//ymm15 = ymm15 + ymm11;// Add
//x+=16;
//y+=16;
__asm__ __volatile__(
"\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 8*8(%[x]), %%ymm2 \n\t"
"vmovapd 12*8(%[x]), %%ymm3 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vmovapd 8*8(%[y]), %%ymm6 \n\t"
"vmovapd 12*8(%[y]), %%ymm7 \n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vmulpd %%ymm2 , %%ymm6 , %%ymm10\n\t"
"vmulpd %%ymm3 , %%ymm7 , %%ymm11\n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"vaddpd %%ymm10, %%ymm14, %%ymm14\n\t"
"vaddpd %%ymm11, %%ymm15, %%ymm15\n\t"
"\n\t"
"subq $-16*8, %[x]\n\t"
"subq $-16*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
if( n & 8 ){
//ymm0 = *(x ); // Load
//ymm1 = *(x+1); // Load
//ymm4 = *(y ); // Load
//ymm5 = *(y+1); // Load
//ymm8 = ymm0 * ymm4; // Mul
//ymm9 = ymm1 * ymm5; // Mul
//ymm12 = ymm12 + ymm8; // Add
//ymm13 = ymm13 + ymm9; // Add
//x+=8;
//y+=8;
__asm__ __volatile__(
"\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"\n\t"
"subq $-8*8, %[x]\n\t"
"subq $-8*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
if( n & 4 ){
//ymm0 = *x; // Load
//ymm4 = *y; // Load
//ymm8 = ymm0 * ymm4; // Mul
//ymm12 = ymm12 + ymm8; // Add
//x+=4;
//y+=4;
__asm__ __volatile__(
"\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"\n\t"
"subq $-4*8, %[x]\n\t"
"subq $-4*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
if( n & 2 ){
//ymm0 = *x; // Load
//ymm4 = *y; // Load
//ymm8 = ymm0 * ymm4; // Mul
//ymm12 = ymm12 + ymm8; // Add
//x+=2;
//y+=2;
__asm__ __volatile__(
"\n\t"
"movapd 0*8(%[x]), %%xmm0 \n\t"
"movapd 0*8(%[y]), %%xmm4 \n\t"
"mulpd %%xmm0 , %%xmm4 \n\t"
"addpd %%xmm4 , %%xmm12\n\t"
"\n\t"
"subq $-2*8, %[x]\n\t"
"subq $-2*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
if( n & 1 ){
//ymm0 = *x; // Load
//ymm4 = *y; // Load
//ymm8 = ymm0 * ymm4; // Mul
//ymm12 = ymm12 + ymm8; // Add
//x++;
//y++;
__asm__ __volatile__(
"\n\t"
"movsd 0*8(%[x]), %%xmm0 \n\t"
"movsd 0*8(%[y]), %%xmm4 \n\t"
"mulsd %%xmm0 , %%xmm4 \n\t"
"addsd %%xmm4 , %%xmm12\n\t"
"\n\t"
"subq $-1*8, %[x]\n\t"
"subq $-1*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
//value = ymm12 + ymm13 + ymm14 + ymm15;
__asm__ __volatile__(
"\n\t"
"vaddpd %%ymm12, %%ymm13, %%ymm13\n\t"
"vaddpd %%ymm14, %%ymm15, %%ymm15\n\t"
"vaddpd %%ymm13, %%ymm15, %%ymm15\n\t"
"vperm2f128 $0x01, %%ymm15, %%ymm15, %%ymm14\n\t" // exchange |a|b|c|d| -> |c|d|a|b|
"vhaddpd %%ymm14, %%ymm15, %%ymm15\n\t"
"vhaddpd %%ymm15, %%ymm15, %%ymm15\n\t"
"movsd %%xmm15, %[v] \n\t"
"\n\t"
:[v]"=m"(value)
);
/********************************************************
ALIGNMENT CASE 1, 2, 3
********************************************************/
}else{
上記コードは、アライメント判定値が0の場合です。それ以外の場合に対しては、vmovapd命令とmovapd命令を、それぞれvmovupd命令とmovupd命令に置換したコードがこの後に続いています。
C言語のコードを、そのままアセンブラ命令に置き換えただけなので、ほとんど説明は不要でしょう。少し違う点としては、C言語ではベクトルx,yを1要素ずつロードしていたのに対して、vmovapd命令では4要素ずつロードしている点があります。また、16段アンロールにしたため、端数処理が増えています。
また、最後の全ての値を合計するコードは、ベクトル要素を全て足し合わせるために、置換命令vperm2f128と水平加算命令vhaddpdを駆使しています。
ループをずらす
このままだとC言語とほぼ変わりません。しかし、ロード命令→演算命令という順序は、RAW型のパイプラインハザードを起こしやすいので、逆順になるようにループをずらします。
n_unroll = (n>>4); // Unroll by 16 elements
__asm__ __volatile__(
"\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 8*8(%[x]), %%ymm2 \n\t"
"vmovapd 12*8(%[x]), %%ymm3 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vmovapd 8*8(%[y]), %%ymm6 \n\t"
"vmovapd 12*8(%[y]), %%ymm7 \n\t"
"\n\t"
"subq $-16*8, %[x]\n\t"
"subq $-16*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
n_unroll--;
while( n_unroll-- ){
__asm__ __volatile__(
"\n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vmulpd %%ymm2 , %%ymm6 , %%ymm10\n\t"
"vmulpd %%ymm3 , %%ymm7 , %%ymm11\n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"vaddpd %%ymm10, %%ymm14, %%ymm14\n\t"
"vaddpd %%ymm11, %%ymm15, %%ymm15\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 8*8(%[x]), %%ymm2 \n\t"
"vmovapd 12*8(%[x]), %%ymm3 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vmovapd 8*8(%[y]), %%ymm6 \n\t"
"vmovapd 12*8(%[y]), %%ymm7 \n\t"
"\n\t"
"subq $-16*8, %[x]\n\t"
"subq $-16*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
__asm__ __volatile__(
"\n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vmulpd %%ymm2 , %%ymm6 , %%ymm10\n\t"
"vmulpd %%ymm3 , %%ymm7 , %%ymm11\n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"vaddpd %%ymm10, %%ymm14, %%ymm14\n\t"
"vaddpd %%ymm11, %%ymm15, %%ymm15\n\t"
"\n\t"
::);
メインループ(n_unrollで回っているループ)について、1回目のロード命令を全てメインループの上に出し、最後の積算命令と加算命令をメインループの下に出しました。1回分の反復が減っているため、n_unrollを先に1回だけデクリメント(n_unroll--)しています。こうすることで、メインループの中では、演算命令とロード命令の順序が逆になりました。
端数処理は、多くとも1回しか計算されないので、変更していません。
アセンブラ命令を並び替える
前回確認したように、CPUの持つ演算器はいくつかの命令を同時に実行することができました。Sandy Bridgeの場合は、ロード命令2つ、乗算命令1つ、加算命令1つが同時に実行できました。
ということは、ロード命令が8個並んでいると、その間は乗算器も加算器も遊んでしまう(ストールしてしまう)ので、ロード命令2個と乗算、加算を並べた方が、より効率的になるはずです。
n_unroll = (n>>4); // Unroll by 16 elements
__asm__ __volatile__(
"\n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 8*8(%[x]), %%ymm2 \n\t"
"vmovapd 12*8(%[x]), %%ymm3 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vmovapd 8*8(%[y]), %%ymm6 \n\t"
"vmovapd 12*8(%[y]), %%ymm7 \n\t"
"\n\t"
"subq $-16*8, %[x]\n\t"
"subq $-16*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
n_unroll--;
while( n_unroll-- ){
__asm__ __volatile__(
"\n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmovapd 0*8(%[x]), %%ymm0 \n\t"
"vmovapd 0*8(%[y]), %%ymm4 \n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vmovapd 4*8(%[x]), %%ymm1 \n\t"
"vmovapd 4*8(%[y]), %%ymm5 \n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"vmulpd %%ymm2 , %%ymm6 , %%ymm10\n\t"
"vmovapd 8*8(%[x]), %%ymm2 \n\t"
"vmovapd 8*8(%[y]), %%ymm6 \n\t"
"vaddpd %%ymm10, %%ymm14, %%ymm14\n\t"
"vmulpd %%ymm3 , %%ymm7 , %%ymm11\n\t"
"vmovapd 12*8(%[x]), %%ymm3 \n\t"
"vmovapd 12*8(%[y]), %%ymm7 \n\t"
"vaddpd %%ymm11, %%ymm15, %%ymm15\n\t"
"\n\t"
"subq $-16*8, %[x]\n\t"
"subq $-16*8, %[y]\n\t"
"\n\t"
:[x]"=r"(x),[y]"=r"(y)
:"0"(x),"1"(y)
);
}
__asm__ __volatile__(
"\n\t"
"vmulpd %%ymm0 , %%ymm4 , %%ymm8 \n\t"
"vmulpd %%ymm1 , %%ymm5 , %%ymm9 \n\t"
"vmulpd %%ymm2 , %%ymm6 , %%ymm10\n\t"
"vmulpd %%ymm3 , %%ymm7 , %%ymm11\n\t"
"vaddpd %%ymm8 , %%ymm12, %%ymm12\n\t"
"vaddpd %%ymm9 , %%ymm13, %%ymm13\n\t"
"vaddpd %%ymm10, %%ymm14, %%ymm14\n\t"
"vaddpd %%ymm11, %%ymm15, %%ymm15\n\t"
"\n\t"
::);
ロード命令vmovapdを、計算結果が変わらない範囲で、可能な限り上に移動させました。
最終測定してみる
では、ベースコード(アンロール済み)と比較して、速くなったのか確かめてみましょう。
ベクトルの要素数が大きくなってしまうと、あまり変わりませんね。4095要素から65536要素くらいだと、2〜3倍のスピードが出るようです。
まとめ
内積計算を高速化しました。結果として、演算の時間はほとんど隠れてしまい、メモリアクセスに依存していました。
サンプルコードをGitHubで公開しておきます。gitのログを遡れば、修正途中のコードもみられます。