見出し画像

BigQueryのユーザー定義集約関数(UDAF)を使った線形回帰係数と予測値の計算方法

こんにちは、マネーフォワードケッサイ開発本部の tamiya です。

この記事では、BigQuery のユーザー定義集約関数(User Defined Aggregate Functions, UDAF)という機能(※2024年7月現在、Pre-GA 提供中)を使って、線形回帰係数や予測値を計算する方法について解説します。

ことの始まり

例えば、下記のような時系列データがあったとします。
日々の売り上げやユーザー数のようなデータだと思ってください。

この時、データの傾きから大まかな上昇・下降トレンドを評価したいと思います。
そのために、線形回帰(単回帰)で直線にフィットさせ、その傾き(回帰係数)を得ることを考えます。

さて、SQL では合計値 (SUM) や平均値 (AVG) 、最大・最小値 (MAX, MIN) といった集計は関数一つで行えますが、データの傾きのようなものは一発では計算できません。

しかし、BigQuery が提供している関数を組み合わせれば線形回帰係数を計算することができますし、30日後の予測値なども求めることができます。

そこでこの記事では BigQuery の関数を用いて線形回帰を行う方法を解説したうえで、処理を何度も使いまわせるようにユーザー定義集約関数(UDAF)にしてまとめる方法についても紹介します。

BigQuery で線形回帰係数を計算してみる

まずは、以下のようなデータを例に BigQuery の関数を用いて線形回帰を行う方法について解説します:

WITH sample_data AS (
  SELECT 1 AS x, 2 AS y UNION ALL
  SELECT 2 AS x, 5 AS y UNION ALL
  SELECT 3 AS x, 3 AS y UNION ALL
  SELECT 4 AS x, 9 AS y UNION ALL
  SELECT 5 AS x, 11 AS y
)
SELECT * FROM sample_data

これをグラフにすると、以下の黒線のようになります:

青の点線は、線形回帰により引いた直線です(上記の図では R を用いて計算)。

この線形回帰直線は、以下のように表されます:

$$
y = \hat{a} x + \hat{b}
$$

$${\hat{a}}$$ が回帰直線の傾きを表す回帰係数です。

$${N}$$ 個のデータ点 $${\{ (x_1, y_1), (x_2, y_2), …, (x_N, y_N))\} }$$ が与えられた時、この回帰係数(の推定量)$${\hat{a}}$$ は以下の形で表すことができます:

$$
\hat{a} = \frac{\sum_{i=1}^N (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^N (x_i - \bar{x})^2}
$$

$${\bar{x}, \bar{y}}$$ はそれぞれ標本平均 $${\bar{x} = \frac{1}{N}\sum_{i=1}^N x_i}$$, $${\bar{y} = \frac{1}{N}\sum_{i=1}^N y_i}$$ です。
式の導出などの詳細は統計学や計量経済学の教科書を参照してください。

ところで、$${x}$$ の標本分散と、$${x, y}$$ の標本共分散はそれぞれ以下の式で表されます:

$$
\frac{1}{N-1}\sum_{i=1}^N (x_i - \bar{x})^2
$$

$$
\frac{1}{N-1}\sum_{i=1}^N (x_i - \bar{x})(y_i - \bar{y})
$$

したがって、前述の回帰係数推定量は、$${x,y}$$ の標本共分散を $${x}$$ の標本分散で割ったものとして求めることができます。

BigQuery では標本共分散は `COVAR_SAMP(X, Y)`、標本分散は `VAR_SAMP(X)` という関数で計算できます。
そこで、次のようなクエリを書けば BigQuery で回帰係数推定量を求めることができます:

SELECT
  SAFE_DIVIDE(COVAR_SAMP(x, y), VAR_SAMP(x)) AS linreg_coef
FROM
  sample_data

`SAFE_DIVIDE` というのは、分母が0の時に NULL を返す割り算関数です。
単純に割り算を計算するなら `COVAR_SAMP(x, y) / VAR_SAMP(x)` ですが、これでは $${x}$$ の標本分散が0のとき($${x}$$ が常に一定の値を持つ時など)にはエラーになり実行が止まってしまうので今回は `SAFE_DIVIDE` を使いました。

なお、BigQuery で分散・共分散を計算する関数には、上記で用いた `VAR_SAMP`, `COVAR_SAMP` のほかに、$${N-1}$$ の代わりに $${N}$$ で割った `VAR_POP`, `COVAR_POP` もあります。
しかし、今回の回帰係数推定量の計算ではどちらを用いても結果は同じになります。

ユーザー定義集約関数 (UDAF) にまとめる

上記の回帰係数を求める処理は、そこまで複雑ではないとはいえ何度も書くとなると少し大変です。
例えば回帰係数 $${\hat{a}}$$ だけでなく定数項 $${\hat{b} = \bar{y} - \hat{a} \bar{x}}$$ も求める必要が出た場合、`AVG(y) - SAFE_DIVIDE(COVAR_SAMP(x, y), VAR_SAMP(x)) * AVG(x)` のようにもう一度回帰係数の計算処理を書くか、またはクエリを分ける必要が生じます。

そこで、回帰係数の計算処理をユーザー定義集約関数(User Defined Aggregate Functions, UDAF)にまとめることで処理を使いまわせるようにしてしまおうと思います。

回帰係数・定数項・予測値の集計処理を UDAF にする

クエリ内で下記のように記述することで、`LINEAR_REG_COEF` という名前で線形回帰係数を計算するための集約関数を定義できます:

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_COEF(x FLOAT64, y FLOAT64)
RETURNS FLOAT64
AS (
  (SAFE_DIVIDE(COVAR_SAMP(x, y), VAR_SAMP(x))) 
)

`CREATE TEMP AGGREGATE FUNCTION`  が UDAF を定義するための命令文であり、`TEMP` は利用可能な範囲がクエリ内のみであることを意味します。

これを用いて、定数項 $${\hat{b} = \bar{y} - \hat{a} \bar{x}}$$ についても UDAF として以下のように定義してしまおうと思います:

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_CONST(x FLOAT64, y FLOAT64)
RETURNS FLOAT64
AS (
  (AVG(y) - LINEAR_REG_COEF(x, y) * AVG(x)) 
)

途中で先ほど定義した回帰係数を算出する UDAF `LINEAR_REG_COEF` を使いました。このように、UDAF の中から先に定義しておいた別の UDAF を呼び出すこともできます。

ついでなので、$${x}$$ が指定した値のときの予測値 $${\hat{y} = \hat{a} x + \hat{b}}$$ を集計する処理も UDAF にしてしまおうと思います:

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_PRED(x FLOAT64, y FLOAT64, x_pred FLOAT64 NOT AGGREGATE)
RETURNS FLOAT64
AS (
  (LINEAR_REG_COEF(x, y) * x_pred + LINEAR_REG_CONST(x, y))
)

上記では、予測対象の $${x}$$ の値を入れる `x_pred` について、引数の定義の際に `x_pred FLOAT64 NOT AGGREGATE` と記述しています。
これは、テーブルのカラムを指定する `x`, `y` とは異なり `x_pred` には定数を入れるためです。

実行

以上をまとめると、以下のように実行します:

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_COEF(x FLOAT64, y FLOAT64)
RETURNS FLOAT64
AS (
  (SAFE_DIVIDE(COVAR_SAMP(x, y), VAR_SAMP(x))) 
);

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_CONST(x FLOAT64, y FLOAT64)
RETURNS FLOAT64
AS (
  (AVG(y) - LINEAR_REG_COEF(x, y) * AVG(x)) 
);

CREATE TEMP AGGREGATE FUNCTION LINEAR_REG_PRED(x FLOAT64, y FLOAT64, x_pred FLOAT64 NOT AGGREGATE)
RETURNS FLOAT64
AS (
  (LINEAR_REG_COEF(x, y) * x_pred + LINEAR_REG_CONST(x, y))
);

WITH sample_data AS (
  SELECT 1 AS x, 2 AS y UNION ALL
  SELECT 2 AS x, 5 AS y UNION ALL
  SELECT 3 AS x, 3 AS y UNION ALL
  SELECT 4 AS x, 9 AS y UNION ALL
  SELECT 5 AS x, 11 AS y
)

SELECT
  LINEAR_REG_COEF(x, y) AS linreg_coef,
  LINEAR_REG_CONST(x, y) AS linreg_const,
  LINEAR_REG_PRED(x, y, 6) AS linreg_pred
FROM sample_data

`linreg_coef`, `linreg_const` がそれぞれ回帰係数と定数項です。
`LINEAR_REG_PRED(x, y, 6)` により集計した `linreg_pred` は、$${x=6}$$ のときの予測値になります。

発展: 時系列データの場合

一通りシンプルなデータでの処理方法がわかったので、当初やりたかったことに戻って時系列データに適用したいと思います。

ここでは、データの上昇・下降トレンドを、月毎に出そうと思います。
下記は、本記事冒頭の時系列データにおいて R を用いて月ごとに回帰直線を引いたものです。

以下では、BigQuery の UDAF を用いて月ごとの線形回帰係数を求められるようにします。

日付データ用の回帰集約関数の作成

時系列データに適用する場合、日付をそのまま説明変数 (`x`) として入力することができないので、適当な数値に変換します。
最も簡単なのは、`UNIX_DATE` 関数を適用することです。これにより、1970-01-01 からの経過日数に変換します。

これを用いて先ほどの回帰係数を計算する UDAF を書き換えて、引数 x に DATE 型を取れるようにしたのが以下です:

CREATE AGGREGATE FUNCTION `project_id.dataset_name.LINEAR_REG_COEF_DATE`(x DATE, y FLOAT64)
RETURNS FLOAT64
AS (
  (SAFE_DIVIDE(COVAR_SAMP(UNIX_DATE(x), y), VAR_SAMP(UNIX_DATE(x)))) 
)

なお、今回は UDAF を指定したプロジェクト・データセット下に保存することで、別のクエリを書く際にも利用できるようにしました。

ついでなので、定数項についても同様に時系列版を作成し保存してしまいます:

CREATE AGGREGATE FUNCTION `project_id.dataset_name.LINEAR_REG_CONST_DATE`(x DATE, y FLOAT64)
RETURNS FLOAT64
AS (
  (AVG(y) - `project_id.dataset_name.LINEAR_REG_COEF_DATE`(x, y) * AVG(UNIX_DATE(x))) 
)

予測については、「データ中の集計グループ内の最新の日付から、`n_days_ahead` 日後の値を予測する」という形にして以下のように定義しました:

CREATE AGGREGATE FUNCTION `project_id.dataset_name.LINEAR_REG_PRED_DATE`(x DATE, y FLOAT64, n_days_ahead INT64 NOT AGGREGATE)
RETURNS FLOAT64
AS (
  (`project_id.dataset_name.LINEAR_REG_CONST_DATE`(x, y) + `project_id.dataset_name.LINEAR_REG_COEF_DATE`(x, y) * (UNIX_DATE(MAX(x)) + n_days_ahead) )
)

実行

これらを用いて月毎に集計するクエリを書くと、以下のようになります:

SELECT
  DATE_TRUNC(date, MONTH) AS month,
  `project_id.dataset_name.LINEAR_REG_COEF_DATE`(date,
    y) AS linreg_coef,
  `project_id.dataset_name.LINEAR_REG_CONST_DATE`(date,
    y) AS linreg_const,
  `project_id.dataset_name.LINEAR_REG_PRED_DATE`(date,
    y,
    30) AS linreg_pred
FROM
  `project_id.dataset_name.sample_data_for_note_on_udaf`
GROUP BY
  month

`linreg_coef` が月ごとの回帰係数(≒ トレンド)になります。その月の中では1日あたりどれくらい増加・減少しているかを表しています。
`linreg_const` は定数項ですが、今回の場合は 1970-01-01 時点の予測値を意味するので、あまり出力する意味がないかもしれません。

`linreg_pred` は月末から30日後の予測値です。
例えば 3 行目(`month = 2014-03-01`)では、2014-03-31から30日後なので2014-04-30における予測値を表しています。「3月中の増減トレンドがそのまま4月以降も続くのであれば、4月末の値は129.88になる」ということを意味しています。

まとめ

今回は BigQuery 関数を組み合わせて線形回帰の計算を行い、さらに処理内容をユーザー定義集約関数(UDAF)にまとめることで使いまわせるようにしました。
本記事の例では日次の時系列データに対して月毎のトレンドを集計するために利用しましたが、そのほかユーザーごとや属性ごとのトレンドを定量的に見る際などにも使えます。

また、線形回帰以外にも、何度も使う集約処理については UDAF にすると重複して処理を書く手間が省けて便利になる場面があるかもしれません。

なお、BigQuery 上で線形回帰を行う方法としては、ほかに BigQuery ML (BQML)の `LINEAR_REG` もあります。

しかし、BQML ではモデル作成・予測ステップごとにクエリを書く必要があり、今回の想定ユースケースに対しては過剰に手数が多く煩雑になってしまいます。
なにより、今回月ごとのトレンドを見る際にやったような、グループごとの統計量を一度に算出することが不可能です。
一変数の線形回帰係数や予測値を求めるくらいの簡単な処理であれば、集約関数を使って書いてしてしまう方が圧倒的に良いでしょう。


いいなと思ったら応援しよう!