pythonで連続値の最頻値を手っ取り早く取得
こんにちは。株式会社Rosso、AI部です。
今回のテーマは「pythonで連続値の最頻値を手っ取り早く取得する方法」です。短めにまとめてますので、ぜひ最後までお読みください。
はじめに
連続値から密度関数を作るにはカーネル密度推定をすればよいですが、
パパっとカーネル密度推定をする方法が意外とないようです。
今回は、代わりにpandasのDataFrameのplotメソッド、またはsns.kdeplotでグラフ描画の際に取得できるデータを利用する方法を紹介します。
連続値の最頻値の導出
データ作成
最頻値の導出のテストのために
平均値・中央値・最頻値がそれぞれ異なるガンマ分布によって値を生成。
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
arr = np.random.gamma(1, 0.1, 1000)
"""
array([2.87936854e-02, 3.70731463e-02, 1.77116636e-01, 1.69945689e-01,
4.65948818e-02, 1.40371776e-01, 2.06753746e-02, 4.83618765e-01,
6.82041784e-02, 1.03181197e-01, 1.99465678e-02, 7.43770849e-02,
5.07223702e-02, 2.06828856e-02, 2.83959190e-01, 3.46902340e-02,
...
1.86675666e-02, 1.18242077e-01, 2.07714469e-01, 1.28203455e-01])
"""
グラフ出力のデータを取得
以下のようにグラフ出力を変数に格納すればget_linesメソッドなどでグラフ描画のx,y座標の情報が取得できる。
res = pd.Series(arr).plot(kind="kde")
xydata = res.get_lines()[0].get_xydata()
plt.close()
"""
array([[-4.15598706e-01, 1.06875007e-51],
[-4.13933748e-01, 2.71265008e-51],
[-4.12268790e-01, 6.85964561e-51],
...,
[ 1.24436424e+00, 3.53726536e-52],
[ 1.24602920e+00, 1.40386191e-52],
[ 1.24769415e+00, 5.55091339e-53]])
"""
y座標が最も高くなるx座標を取得してしまえばそれが最頻値となる。
x = xydata[:,0]
y = xydata[:,1]
#最頻値
mode = x[y.argmax()]
確認
plt.plot(x,y)
plt.axvline(mode,label="最頻値",c="C1")
plt.axvline(np.median(arr),label="中央値",c="C2",ls="--",alpha=0.5)
plt.axvline(arr.mean(),label="平均値",c="C3",ls="--",alpha=0.5)
plt.legend()
最頻値の算出方法まとめ
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#import seaborn as sns
def get_mode(arr):
res = pd.Series(arr).plot(kind="kde")
#seabornでも可能
#res = sns.kdeplot(arr)
plt.close()
xydata = res.get_lines()[0].get_xydata()
x = xydata[:,0]
y = xydata[:,1]
mode = x[y.argmax()]
return mode
注意:逆に離散変数で行うと想定外の結果に
#二項分布で整数のみの乱数生成
arr = np.random.binomial(n=100,p=0.5,size=1000)
modeA = get_mode(arr)
series = pd.Series(arr)
modeB = series.mode()[0]
fig = plt.figure()
ax = fig.add_subplot(111)
sns.histplot(series,kde=True,bins=list(range(arr.min(),arr.max())))
plt.axvline(modeA,label="カーネル密度推定からの最頻値",c="C1")
plt.axvline(modeB,label="離散値の最頻値",c="C2")
plt.legend(loc ="upper left",bbox_to_anchor=(1,1))
カーネル密度推定を使った最頻値では51.3168…
離散値の最頻値では49
と、タスクによっては想定と異なる結果になる場合があるので注意が必要。
参考
最頻値の求め方は0-1型単純損失におけるf(a)の最大化。
How to extract data from a pandas plot?
matplotlibの結果からプロットデータを取得する方法について参考に。
カーネル密度推定(KDE: Kernel Density Estimation)をpython (numpy)で実装・整理してみる
ちゃんとカーネル密度推定したい場合に参考になります。
やり方によって最頻値が複数ある場合に1つしか算出されなかったりします。
最後までお読みいただき、ありがとうございました!