Python高速化メモ(Numba&並列化)
Pythonはプログラミングが容易でライブラリも豊富なので便利なのだけれども、唯一の泣き所はNumpyなどのコンパイルされたライブラリ以外の実行がとても遅いところだね。そこで任意の関数をJIT(Just-In-Time)コンパイルを行ってもらえるNumbaの用いたときの実行速度を計測してみるね。
サンプルプログラム:素数の数え上げ
#素数判定関数
def checkPrimeNumber( n ):
if( n <= 1 ): return False
if( n == 2 ): return True
for m in range( 2, n ):
if(n % m == 0): return False
return True
#素数の数を数える
def counting( N ):
nn = 0
for n in range( N ):
if( checkPrimeNumber( n ) ): nn += 1
return nn
counting関数は引数で指定した整数までの素数の数を数え上げるね。この関数を用いて200000までの素数を数え上げて、計算終了までの時間を計測するね。
############################################
# Python利用
start = time.time()
result = counting( 200000 )
print(result)
end = time.time()
time1 = end - start
print("Python:", time1, "[s]")
############################################
比較として、Numbaを導入してJIT化した同様の関数を次のように定義するよ。
#素数判定関数
@jit( bool( int32 ), nopython=True )
def jit_checkPrimeNumber( n ):
if( n <= 1 ): return False
if( n == 2 ): return True
for m in range( 2, n ):
if(n % m == 0): return False
return True
#素数の数を数える
@jit( int32( int32 ), nopython=True )
def jit_counting( N ):
nn = 1 #2
for n in range(3, N, 2):
if( jit_checkPrimeNumber( n ) ): nn += 1
return nn
同様にこの関数を用いて200000までの素数を数え上げて、計算終了までの時間を計測するね。
############################################
# Numba利用
start = time.time()
result = jit_counting( 200000 )
print(result)
end = time.time()
time2 = end-start
print("Numba:", time2, "[s]")
#############################################
さらに、jit_counting関数のforループを並列化した関数を次のように準備するよ。
#素数の数を数える
@jit( int32( int32 ), nopython=True, parallel=True )
def p_jit_counting( N ):
nn = 0
set_num_threads(NUMBER_OF_THREADS)
for n in prange( N ):
if( jit_checkPrimeNumber( n ) ): nn += 1
return nn
同様にこの関数を用いて200000までの素数を数え上げて、計算終了までの時間を計測するね。
#############################################
# Numba利用(並列化)
start = time.time()
result = p_jit_counting( 200000 )
print(result)
end = time.time()
time3 = end-start
print("Numba Parallel:", time3, "[s]")
############################################
計測結果
Python: 56.49 [s]
Numba: 2.390 [s]
Numba Parallel: 0.4291 [s]
なんと、Pythonを基準として、
Numba: 23.64倍
Numba Parallel : 131.6倍
で、並列化(並列数16)で5倍になったよ。
こんな簡単なプログラムで130倍も高速化ができたね。今回のような素数の数え上げは、並列の分割が上記のように単純な場合にはスレッドごとの負荷が全然異なるため、並列数16に対して5倍程度になってしまうね。計算負荷を均一化すれば原理的には16倍近くまで速くなりようだね。
Python プログラムソース
プログラムソース全文を以下に記します。これからも応援よろしくお願いしまーす。
import time
from numba import jit, prange, set_num_threads, uint8, int32, int64, float64, complex128, void, bool
NUMBER_OF_THREADS = 16
#素数判定関数
def checkPrimeNumber( n ):
if( n <= 1 ): return False
if( n == 2 ): return True
for m in range( 2, n ):
if(n % m == 0): return False
return True
#素数の数を数える
def counting( N ):
nn = 0
for n in range( N ):
if( checkPrimeNumber( n ) ): nn += 1
return nn
#素数判定関数
@jit( bool( int32 ), nopython=True )
def jit_checkPrimeNumber( n ):
if( n <= 1 ): return False
if( n == 2 ): return True
for m in range( 2, n ):
if(n % m == 0): return False
return True
#素数の数を数える
@jit( int32( int32 ), nopython=True )
def jit_counting( N ):
nn = 1 #2
for n in range(3, N, 2):
if( jit_checkPrimeNumber( n ) ): nn += 1
return nn
#素数の数を数える
@jit( int32( int32 ), nopython=True, parallel=True )
def p_jit_counting( N ):
nn = 0
set_num_threads(NUMBER_OF_THREADS)
for n in prange( N ):
if( jit_checkPrimeNumber( n ) ): nn += 1
return nn
'''
@jit( int32( int32 ), nopython=True, parallel=True )
def p_jit_counting( N ):
nn = 0
set_num_threads(num_threads)
for thread_id in prange(num_threads):
for n in range( thread_id, N, num_threads):
if( jit_checkPrimeNumber( n ) ): nn += 1
return nn
'''
############################################
# Python利用
start = time.time()
result = counting( 200000 )
print(result)
end = time.time()
time1 = end - start
print("Python:", time1, "[s]")
############################################
# Numba利用
start = time.time()
result = jit_counting( 200000 )
print(result)
end = time.time()
time2 = end-start
print("Numba:", time2, "[s]")
#############################################
# Numba利用(並列化)
start = time.time()
result = p_jit_counting( 200000 )
print(result)
end = time.time()
time3 = end-start
print("Numba Parallel:", time3, "[s]")
############################################
#速度倍率
print ( "Python / Numba:", time1/time2 )
print ( "Python / Numba Parallel:", time1/time3 )
print ( "Numba / Numba Parallel:", time2/time3 )