Python3でnumpy.lib.stride_tricks.as_strided()とちょっと仲良くなる話。

1. numpy.lib.stride_tricks.as_strided()とは何者なの?

numpy.lib.stride_tricks.as_strided()の主な引数は(arr, shape, strides)の3つになっています。この関数の説明をするためには、まずはshapeとstridesについて解説する必要があるでしょう。
とりあえずnumpyをimportして配列を作ってからいろいろ見ていきましょう。

import numpy as np

A = np.arange(10, dtype='float64')

# A = array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
# A.shape = (10,)
# A.strides = (8,)
# A.itemsize = 8

shapeのデフォルトはA.shapeの値となっています。そのままだと形状の変換を行わないということです。ここでx*y=10となるようにshape=(x,y)と設定し、stridesをデフォルトにしておけばA.reshape(x,y)と同じ働きになります。

B1 = np.lib.stride_tricks.as_strided(A, shape=(10,))
# B1 = array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

B2 = np.lib.stride_tricks.as_strided(A, shape=(5, 2))
# B2 = array([[0., 1.],
#            [2., 3.],
#            [4., 5.],
#            [6., 7.],
#            [8., 9.]])

B3 = np.lib.stride_tricks.as_strided(A, shape=(3,))
# B3 = array([0., 1., 2.])

B4 = np.lib.stride_tricks.as_strided(A, shape=(4, 3))
# B4 = array([[ 0.00000000e+000,  1.00000000e+000,  2.00000000e+000],
#            [ 3.00000000e+000,  4.00000000e+000,  5.00000000e+000],
#            [ 6.00000000e+000,  7.00000000e+000,  8.00000000e+000],
#            [ 9.00000000e+000,  2.56765117e-312, -2.39841233e-212]])

注意点として、B4のようにstridesがデフォルトの状態でshapeをAの範囲を超える値に設定すると、意味のない値を吐き出すようになります。ゼロで埋めてくれそうに見えてそんなことはないんですね。
次にstridesを設定してみましょう。stridesとは、ざっくり言うと値を参照する間隔のことです。例えば、Aから1つ飛ばしで要素を参照するなら、stridesに設定すべき値は2*itemsizeです。ここでitemsizeは要素のバイト数を表していて、A.itemsizeで参照できる値です。今回は8ですね。では、実際にやってみましょう。

C = np.lib.stride_tricks.as_strided(A, shape=(5,), strides=(16,))
# C = array([0., 2., 4., 6., 8.])

2. で、何の役に立つの?

これだけでは何のありがたみもありませんが、例えばこれを応用すれば次のように畳み込みが行えたりします。わかりやすいようにAの値を少し変えてみましょう。

A = np.array([2, 5, 1, 8, 3, 7, 0, 4, 6, 9], dtype='float64')

D = np.lib.stride_tricks.as_strided(A, shape=(9,2), strides=(8,8))
E = np.array([0.2, 0.1])
# D = array([[2., 5.],
#            [5., 1.],
#            [1., 8.],
#            [8., 3.],
#            [3., 7.],
#            [7., 0.],
#            [0., 4.],
#            [4., 6.],
#            [6., 9.]])

F = np.sum(D*E, axis=1)
# F = array([0.9, 1.1, 1. , 1.9, 1.3, 1.4, 0.4, 1.4, 2.1])

といった具合です。3個ずつ取り出して平均を出すような操作も簡単にできます。

D2 = np.lib.stride_tricks.as_strided(A, shape=(8,3), strides=(8,8))
# D2 = array([[2., 5., 1.],
#             [5., 1., 8.],
#             [1., 8., 3.],
#             [8., 3., 7.],
#             [3., 7., 0.],
#             [7., 0., 4.],
#             [0., 4., 6.],
#             [4., 6., 9.]])

G = np.sum(D2, axis=1)/3
# G = array([2.66666667, 4.66666667, 4.        , 6.        , 3.33333333,
#            3.66666667, 3.33333333, 6.33333333])

np.lib.stride_tricks.as_strided()は配列要素をシフトしながら全て参照するという動作が出来るので、N個から2個取り出すような順列なら容易に実装できます。試しに5個から2個取り出してみましょう。

A = np.arange(5, dtype='float64')
B = np.concatenate((A, A[:-1]))
# B = array([0., 1., 2., 3., 4., 0., 1., 2., 3.])

I1 = np.lib.stride_tricks.as_strided(B[1:], shape=(4, 5), strides=(8, 8)).reshape(-1, 1)
# I1 = array([[1.], [2.], [3.], [4.], [0.], [2.], [3.], [4.], [0.], [1.],
#             [3.], [4.], [0.], [1.], [2.], [4.], [0.], [1.], [2.], [3.]])

I2 = np.broadcast_to(A[:,None], (5, 4)).T.reshape(-1, 1)
# I2 = array([[0.], [1.], [2.], [3.], [4.], [0.], [1.], [2.], [3.], [4.],
#             [0.], [1.], [2.], [3.], [4.], [0.], [1.], [2.], [3.], [4.]])

J = np.concatenate((I2, I1), axis=1)
# J = array([[0., 1.], [1., 2.], [2., 3.], [3., 4.], [4., 0.],
#            [0., 2.], [1., 3.], [2., 4.], [3., 0.], [4., 1.],
#            [0., 3.], [1., 4.], [2., 0.], [3., 1.], [4., 2.],
#            [0., 4.], [1., 0.], [2., 1.], [3., 2.], [4., 3.]])

といった感じです。

次回は画像をセグメンテーションして領域の周の長さを求める話とかをするような気がします。しないかもしれません。

いいなと思ったら応援しよう!