見出し画像

最適輸送とPOTまとめ

はじめに

こんにちは!ハルです.今回は自分の見直し用に最適輸送(Optimal Transport)や最適輸送問題を解くことができるライブラリであるPOT: Python Optimal Transportについて概要を解説します!


最適輸送問題とは?

最適輸送問題とは,確率分布の比較を行う1つの手法になっています.簡単にいうと,ある確率分布を移動させて,比較したい確率分布に一致させるためにかかるコストを計算しようというものになっています.よく使われる例え話としては下の図の青い土山(確率分布$${\alpha}$$})を赤い土山(確率分布$${\beta}$$)に土を運搬する時に最も無駄のない時のコストになります.
連続型確率分布の場合は,確率密度関数が上記の山に当てはまり,離散分布の場合は,各点$${x_i}$$に質量$${a_i}$$の山が存在するという風に考えます.

離散分布型最適輸送問題

以降,離散分布型最適輸送問題について説明いたします.全然厳密ではないので,気になる場合はぜひ論文等を参照してください.すみません.
最適輸送問題の目的関数と制約条件は以下のようになります.

$$
\begin{align} \underset{\bm{P}} {\text{minimize}} \quad & \sum_{i=1}^{n}\sum_{j=1}^{m} \bm{C}_{ij} \bm{P}_{ij}\\ \text{s.t.} \quad & \bm{P}_{ij} \geq 0 \quad \forall i \in [n], \forall j \in [m] \\ & \sum_{j=1}^{m}\bm{P}_{ij} = \bm{a}_{i} \quad \forall i \in [n] \\ & \sum_{i=1}^{n}\bm{P}_{ij} = \bm{b}_{j} \quad \forall j \in [m] \end{align}
$$


ここで,$${\bm{P} \in \mathbb{R}^{n \times m}}$$で$${\bm{P}{ij}}$$は輸送元の$${i}$$番目の点から,輸送先の$${j}$$番目の点への輸送量を表しており,輸送行列と呼びます.また,$${\bm{C}{ij}}$$は輸送するコストを表し,$${\bm{a}_i}$$は輸送元の$${i}$$番目の点の質量,$${\bm{b}_j}$$は輸送先の$${j}$$番目の点の質量を表しています.
式(2)の制約条件は,全ての輸送量が正となる非不制約であり,式(3)(4)は質量の保存則となっています.

エントロピー正則化つき最適輸送問題

先ほど紹介した,最適輸送問題は解くのにおおよそ,点の数の3乗の時間がかかってしまうため,大規模なデータに対しては適用できないという問題点が存在しました.それを解決するために提案されたのが以下のエントロピー正則化つき最適輸送問題になります.以下に,エントロピー正則化つき最適輸送問題の目的関数と制約条件を示します.

$$
\begin{align} \underset{\bm{P}} {\text{minimize}} \quad & \sum_{i=1}^{n}\sum_{j=1}^{m} \bm{C}_{ij} \bm{P}_{ij} - \lambda H(\bm{P}) \\ \text{s.t.} \quad & \bm{P}_{ij} \geq 0 \quad \forall i \in [n], \forall j \in [m] \\ & \sum_{j=1}^{m}\bm{P}_{ij} = \bm{a}_{i} \quad \forall i \in [n] \\ & \sum_{i=1}^{n}\bm{P}_{ij} = \bm{b}_{j} \quad \forall j \in [m] \end{align}
$$

詳しくは解説しませんが,このエントロピー正則化つき最適輸送問題を高速で解くためにシンクホーンアルゴリズムが提案されています.

POTを使った実装例

今回は主に上記の最適輸送問題とエントロピー正則化つき最適輸送問題の実装例について説明していこうと思います.

データ生成

ot.dataset_2D_samples_gauss(n, m, sigma, random_state)

パラメータ

  •  n:サンプル数

  •  m:ガウス分布の平均

  •  sigma:ガウス分布の分散共分散行列

  •  random_state:シード値

このコードを用いるとガウス分布からサンプリングした点群を得ることができます.実際の使用例を以下に示します.

import ot
import numpy as np
import matplotlib.pyplot as plt

n = 100 # 点群中の点の数

mu_s = [0, 0] # 輸送元の平均
mu_t = [4, 4] # 輸送先の平均

sigma_s = [[1, 0], [0, 1]] # 輸送元の分散 
sigma_t = [[1, 0], [0, 1]] # 輸送先の分散

# データの生成
xs = ot.datasets.make_2D_samples_gauss(n=n, m=mu_s, sigma=sigma_s, random_state=0)
xt = ot.datasets.make_2D_samples_gauss(n=n, m=mu_t, sigma=sigma_t, random_state=1)

# プロット
plt.scatter(xs[:, 0], xs[:, 1], color='r', label='Source')
plt.scatter(xt[:, 0], xt[:, 1], color='b', label='Target')
plt.show()

最適輸送問題を解く

ot.emd(a, b, numItermax, log, center_dual, numThreads)

パラメータ

  • a:輸送元の点群の質量

  • b:輸送先の点群の質量

  • numItermax:最大イテレーション

  • log:コストと双対変数を返すようにする

  • center_dual:双対変数を正規化

  • numThreads:並列処理する場合のスレッド

出力は最適輸送行列になります.上記のlogの通りパラメータ次第で双対変数の値と総輸送コストも出力することができます.

# 重みの計算
a = ot.unif(n)
b = ot.unif(n)

# コスト行列の計算
M = ot.dist(xs, xt)

# POTの計算
P, log = ot.emd(a, b, M, log=True)

# プロット
plt.figure()
plt.imshow(P, cmap='Blues')
plt.title('Optimal Transport')
plt.colorbar()
plt.show()


輸送行列のヒートマップ

エントロピー正則化つき最適輸送問題を解く

ot.sinkhorn(a, b, M, reg, method, numItermax, stopThr, verbose, log, warn, warmstart)

パラメータ(既出のものは除く)

  • reg:正則化パラメータ

  • method:ソルバーで使用するメソッド

  • stooThr:収束条件

  • verbose:Trueにすると反復中に情報を出力する

  • warn:Trueの時アルゴリズムが収束しない場合に警告を出す

  • warmstart:双対変数の初期値を決定する

出力は最適輸送行列になります.また,最適輸送問題と同様で,パラメータによって双対変数およびそう輸送コストを表示することもできます.

# 重みの計算
a = ot.unif(n)
b = ot.unif(n)

# コスト行列の計算
M = ot.dist(xs, xt)

# POTの計算
P, log = ot.sinkhorn(a, b, M, reg=0.1, log=True)

# プロット
plt.figure()
plt.imshow(P, cmap='Blues')
plt.title('Optimal Transport with Entropy Regularization')
plt.colorbar()
plt.show()
エントロピー正則化時の輸送行列のヒートマップ

おわりに

ここまでお読みいただきありがとうございました.自分のメモ用なのでわかりにくい部分もあると思いますが,ぜひコメントで質問してください!

参考文献

いいなと思ったら応援しよう!