見出し画像

pythonで連続値の最頻値を手っ取り早く取得

こんにちは。株式会社Rosso、AI部です。
今回のテーマは「pythonで連続値の最頻値を手っ取り早く取得する方法」です。短めにまとめてますので、ぜひ最後までお読みください。

はじめに

連続値から密度関数を作るにはカーネル密度推定をすればよいですが、
パパっとカーネル密度推定をする方法が意外とないようです。

今回は、代わりにpandasのDataFrameのplotメソッド、またはsns.kdeplotでグラフ描画の際に取得できるデータを利用する方法を紹介します。

連続値の最頻値の導出

データ作成

最頻値の導出のテストのために
平均値・中央値・最頻値がそれぞれ異なるガンマ分布によって値を生成。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

arr = np.random.gamma(1, 0.1, 1000)

"""
array([2.87936854e-02, 3.70731463e-02, 1.77116636e-01, 1.69945689e-01,
       4.65948818e-02, 1.40371776e-01, 2.06753746e-02, 4.83618765e-01,
       6.82041784e-02, 1.03181197e-01, 1.99465678e-02, 7.43770849e-02,
       5.07223702e-02, 2.06828856e-02, 2.83959190e-01, 3.46902340e-02,
...
       1.86675666e-02, 1.18242077e-01, 2.07714469e-01, 1.28203455e-01])
"""


グラフ出力のデータを取得

以下のようにグラフ出力を変数に格納すればget_linesメソッドなどでグラフ描画のx,y座標の情報が取得できる。

res = pd.Series(arr).plot(kind="kde")
xydata = res.get_lines()[0].get_xydata()
plt.close()
"""
array([[-4.15598706e-01,  1.06875007e-51],
       [-4.13933748e-01,  2.71265008e-51],
       [-4.12268790e-01,  6.85964561e-51],
       ...,
       [ 1.24436424e+00,  3.53726536e-52],
       [ 1.24602920e+00,  1.40386191e-52],
       [ 1.24769415e+00,  5.55091339e-53]])
"""
このようなグラフが本来は表示されるが
欲しいのは描画のx,y座標なのでplt.close()で消している

y座標が最も高くなるx座標を取得してしまえばそれが最頻値となる。

x = xydata[:,0]
y = xydata[:,1]

#最頻値
mode = x[y.argmax()]

確認

plt.plot(x,y)
plt.axvline(mode,label="最頻値",c="C1")

plt.axvline(np.median(arr),label="中央値",c="C2",ls="--",alpha=0.5)
plt.axvline(arr.mean(),label="平均値",c="C3",ls="--",alpha=0.5)
plt.legend()
最頻値を取得することができた。

最頻値の算出方法まとめ

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#import seaborn as sns


def get_mode(arr):
    res = pd.Series(arr).plot(kind="kde")
	
    #seabornでも可能
    #res = sns.kdeplot(arr)
	
    plt.close()
    xydata = res.get_lines()[0].get_xydata()
	
    x = xydata[:,0]
    y = xydata[:,1]
    mode = x[y.argmax()]
    return mode

注意:逆に離散変数で行うと想定外の結果に

#二項分布で整数のみの乱数生成
arr = np.random.binomial(n=100,p=0.5,size=1000)
modeA = get_mode(arr)

series = pd.Series(arr)

modeB = series.mode()[0]

fig = plt.figure()
ax = fig.add_subplot(111)

sns.histplot(series,kde=True,bins=list(range(arr.min(),arr.max())))

plt.axvline(modeA,label="カーネル密度推定からの最頻値",c="C1")
plt.axvline(modeB,label="離散値の最頻値",c="C2")

plt.legend(loc ="upper left",bbox_to_anchor=(1,1))
離散変数でグラフから最頻値を導出するとズレます。

カーネル密度推定を使った最頻値では51.3168…
離散値の最頻値では49
と、タスクによっては想定と異なる結果になる場合があるので注意が必要。

参考

最後までお読みいただき、ありがとうございました!

この記事が参加している募集