簡単な行列積計算プログラム|行列積高速化#4
この記事は、以下の記事を分割したものです。
[元の記事]行列積計算を高速化してみる
一括で読みたい場合は、元の記事をご覧ください。
チューニングに必要な準備ができたので、行列積計算プログラムを作成していきます。まず、BLASの行列積ルーチンDGEMMの公式を再確認しておきましょう。
C = alpha * A * B + beta * C
ここで、A, B, Cは行列、alpha, betaはスカラー係数を表しています。
関数インターフェースは、上記のテストプログラムで使用できるように、CBLASのcblas_dgemm関数と同じものとし、関数名をmyblas_dgemmとします。
void myblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
myblas_dgemmでは、第一引数Orderの処理と引数のエラー処理だけを行い、Orderの違いを吸収してしまいます。第一引数Orderの処理は、CblasRowMajorが選ばれた場合に、行列Aと行列Bを入れ替えるという処理をします。これは、cblas_dgemm関数の実装を参照にしています。
実際の計算には、myblas_dgemm_mainという別の関数を用意しています。
void myblas_dgemm_main( gemm_args_t* args );
引数はgemm_args_t型の構造体にまとめることにしました。これは、高速化作業で引数を追加する必要が出てきた時、引数の変更作業を楽にするためです。構造体の定義は、下記の通りです。
typedef struct _gemm_args_t {
size_t TransA;
size_t TransB;
size_t M;
size_t N;
size_t K;
double alpha;
const double *A;
size_t lda;
const double *B;
size_t ldb;
double beta;
double *C;
size_t ldc;
} gemm_args_t;
転置設定のTransAとTransBは、元々enum CBLAS_TRANSPOSE型でしたが、今後のためを考えてビットフラグに変更しました。フラグは、転置を1ビット目、複素共役を2ビット目とし、下記のようなマスクを用意しています。
#define MASK_TRANS 0x01
#define MASK_CONJ 0x02
4-1. myblas_dgemm関数
行列Aと行列Bの入れ替えが必要なCblasRowMajorの場合の引数処理及びエラー処理は、下記のようになります。CblasColMajorの場合は、TransA, TransBの処理やM,Nの入れ替えを元に戻したコードを通るように条件分岐しています。
gemm_args_t args={0,0,0,0,0,0e0,NULL,0,NULL,0,0e0,NULL,0};
int info = 0;
if( Order == CblasColMajor ){
/** 省略 **/
}else if( Order == CblasRowMajor ){
// Transpose Set-up
if( TransA == CblasNoTrans ){ args.TransB = 0; }
if( TransA == CblasTrans ){ args.TransB = MASK_TRANS; }
if( TransA == CblasConjTrans){ args.TransB = MASK_TRANS | MASK_CONJ; }
if( TransB == CblasNoTrans ){ args.TransA = 0; }
if( TransB == CblasTrans ){ args.TransA = MASK_TRANS; }
if( TransB == CblasConjTrans){ args.TransA = MASK_TRANS | MASK_CONJ; }
// Error Check
if( C == NULL ) info=13;
if( B == NULL ) info=10;
if( A == NULL ) info= 8;
if( info ){ myblas_xerbla("myblas_dgemm",info); return; }
int ma = ( (args.TransB & MASK_TRANS) ? M : K );
int mb = ( (args.TransA & MASK_TRANS) ? K : N );
if( ldc < N ) info=14;
if( ldb < mb ) info=11;
if( lda < ma ) info= 9;
if( K < 0 ) info= 6;
if( N < 0 ) info= 5;
if( M < 0 ) info= 4;
if( info ){ myblas_xerbla("myblas_dgemm",info); return; }
// No Computing
if( M == 0 ) return;
if( N == 0 ) return;
if( K == 0 && beta == 1e0 ) return;
if( alpha== 0e0 && beta == 1e0 ) return;
// Set arguments
args.M = N;
args.N = M;
args.K = K;
args.alpha = alpha;
args.A = B;
args.lda = ldb;
args.B = A;
args.ldb = lda;
args.beta = beta;
args.C = C;
args.ldc = ldc;
}
myblas_dgemm_main( &args );
ここで、myblas_xerbla関数は、BLASのエラー出力ルーチンXERBLAを模したものです。下記のような、BLASのエラーメッセージを出力します。
void myblas_xerbla( const char* name, int info ){
printf(" ** On entry to %s parameter number %d had an illegal value\n",name,info);
}
4-2. myblas_dgemm_main関数
実際の行列積計算は、行列Aと行列Bの転置有無の組み合わせで、全部で4パターンのコードが必要になります。転置なしの場合は、次のようなコードになります。
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A += lda;
B++;
}
*C=beta*(*C) + alpha*AB;
A = A - lda*K + 1;
B = B - K;
C++;
}
A = A - M;
B = B + ldb;
C = C - M + ldc;
}
行列CがM×N行列(Mがメモリ連続方向)で、Kに依存しないので、無駄なメモリアクセスを避けるために、外側からN→M→Kの順の三重ループにしています。本来、CはM*N*K回のメモリアクセスが必要ですが、この順にすると、M*N回のメモリアクセスで済みます。また、例えばM→N→Kの順にすると、Cはストライドアクセスが必要になります。これを避けるため、N→M→Kの順にしています。
一方、行列AはM×K行列なので、Kループはメモリ連続方向ではありません。このため、ldaずつ飛び飛びのストライドアクセスになってしまいます。逆に、行列BはK×N行列のため、Kループがメモリ連続方向なので連続アクセスになります。行列Aがどうしてもストライドアクセスになるため、スピードがほとんど出ないと予想されます。
実際、計算速度を測定してみると、次のようになりました。
Max Peak MFlops per Core: 52800 MFlops
Base Peak MFlops per Core: 46400 MFlops
size , elapsed time[s], MFlops, base ratio[%], max ratio[%]
16, 4.05312E-06, 2210.64, 4.76432, 4.18683
32, 3.40939E-05, 2012.33, 4.33691, 3.81123
64, 0.000264883, 2025.71, 4.36575, 3.83657
128, 0.00234389, 1810.43, 3.90179, 3.42885
256, 0.0287192, 1175.21, 2.53278, 2.22577
512, 0.263068, 1023.39, 2.20559, 1.93824
1024, 6.27475, 342.743, 0.738671, 0.649135
2048, 113.116, 151.99, 0.327564, 0.28786
M=N=K=2048で、基本理論ピーク性能の0.3%しか出ていません。どうしようもなく、遅いですね…。
ちなみに、4パターン内で最も高速なのは、行列Aだけを転置した場合で、実装コードは次のようになります。
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A++;
B++;
}
*C=beta*(*C) + alpha*AB;
A = A - K + lda;
B = B - K;
C++;
}
A = A - lda*M;
B = B + ldb;
C = C - M + ldc;
}
行列Aを転置した場合は、最内のKループにおいて、AもBも連続アクセスになっています。この時の、計算速度は次のようになります。
Max Peak MFlops per Core: 52800 MFlops
Base Peak MFlops per Core: 46400 MFlops
size , elapsed time[s], MFlops, base ratio[%], max ratio[%]
16, 4.05312E-06, 2210.64, 4.76432, 4.18683
32, 3.00407E-05, 2283.83, 4.92205, 4.32544
64, 0.000258923, 2072.34, 4.46625, 3.92489
128, 0.0022769, 1863.7, 4.0166, 3.52974
256, 0.0209579, 1610.42, 3.47073, 3.05003
512, 0.178289, 1510.03, 3.25438, 2.85991
1024, 1.40407, 1531.71, 3.3011, 2.90097
2048, 11.1541, 1541.35, 3.32188, 2.91923
この場合は、基本周波数の理論ピーク性能比で3.3%以上のスピードが出ます。
ということで、ストライドアクセスだととても遅いことがわかりますね。
4-3. 初期プログラム
最後に、初期プログラムの全体を載せておきます。
void myblas_dgemm_main( gemm_args_t* args ){
size_t TransA = args->TransA;
size_t TransB = args->TransB;
size_t M = args->M;
size_t N = args->N;
size_t K = args->K;
double alpha = args->alpha;
const double *A = args->A;
size_t lda = args->lda;
const double *B = args->B;
size_t ldb = args->ldb;
double beta = args->beta;
double *C = args->C;
size_t ldc = args->ldc;
double AB;
if( TransA & MASK_TRANS ){
if( TransB & MASK_TRANS ){
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A++;
B+=ldb;
}
*C=beta*(*C) + alpha*AB;
A = A - K + lda;
B = B - ldb*K;
C++;
}
A = A - lda*M;
B = B + 1;
C = C - M + ldc;
}
}else{
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A++;
B++;
}
*C=beta*(*C) + alpha*AB;
A = A - K + lda;
B = B - K;
C++;
}
A = A - lda*M;
B = B + ldb;
C = C - M + ldc;
}
}
}else{
if( TransB & MASK_TRANS ){
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A += lda;
B += ldb;
}
*C=beta*(*C) + alpha*AB;
A = A - lda*K + 1;
B = B - ldb*K;
C++;
}
A = A - M;
B = B + 1;
C = C - M + ldc;
}
}else{
for( size_t j=0; j<N; j++ ){
for( size_t i=0; i<M; i++ ){
AB=0e0;
for( size_t k=0; k<K; k++ ){
AB = AB + (*A)*(*B);
A += lda;
B++;
}
*C=beta*(*C) + alpha*AB;
A = A - lda*K + 1;
B = B - K;
C++;
}
A = A - M;
B = B + ldb;
C = C - M + ldc;
}
}
}
}
これで、ようやく準備が整いました。以降では、高速化の方法を実際にプログラムを書きながら解説していこうと思います。
次の記事
元の記事はこちらです。
ソースコードはGitHubで公開しています。