LeetCode #4 "median-of-two-sorted-arrays"
#LeetCode #難易度Hard #BinarySearch
上から4つ目。難易度がHardとなっている。
問題文
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
You may assume nums1 and nums2 cannot be both empty.
Example 1:
nums1 = [1, 3]
nums2 = [2]
The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5
2つのソート済み整数配列が与えられる。それぞれの配列長はm,nである。これらの両方の整数配列の要素の中央値を求めよ。 ただし時間計算量がO(log (m+n))になるようにせよ。とのこと。
考えた
愚直にやると2つのリストをマージしてソートすることになるが、これだと計算量がO(log (m+n))にならない。計算量がO(log(m))の探索アルゴリズムで、かつソート済み配列を扱うものといえば、バイナリサーチだろう。これは2つの配列を行う変則型といったところ?
2つのリストを合体したと仮定したときの中央に位置するインデックスを割り出す。2つのリストのうち長い方を選び、その中央位置インデックスの要素と、短いほうのリストの最小値を比較する。ここで、長い方のリストの要素のほうが小さければ、それがそのまま中央値となる。逆に、小さかったら、、、どうしよう?
ここで思考停止した。わからん。どうしたらいいのだろう?結局、1つずつ辿ることになるのだろうか?
長時間悩んだ末、どうしても分からんかったので、泣く泣く模範解答を見てみた。
考え方(模範解答より)
日本語訳してみた。英語だけど、わかりやすい文章だった。(小並感)
ソート済み配列内における中央値は、その配列を2等分したときの、断面の要素になる。この問題を解くうえでこの理解が重要になる。今回の問題では2つのソート済み配列が与えられる。これら2つをマージして、その上でそれを2等分したときの断面の要素が、今回求めたい数値ということになる。
今回の問題で与えられる2つの配列をそれぞれ、A、B、と名付ける。Aの長さはmで、Bの長さはnである。
Aを適当な位置でカットすることを考える。このときのカットした位置をiとする。カットした断面より左をleft_A、右をright_Aと名付ける。こんな具合になる↓。
left_A | right_A
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
Bについても同じように、適当な位置でカットする。カットした位置をjとする。カットした断面より左をleft_B、右をright_Bと名付ける。
left_B | right_B
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
left_Aとleft_Bを1つのセットにまとめる。これをleft_partと名付ける。そして、right_Aとright_Bも1つのセットにまとめる。これをright_partと名付ける。
left_part | right_part
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
left_partとright_partが、下記2条件を満たしているとすると、
1.len(left_part)=len(right_part)
2.max(left_part) <= min(right_part)
それは{A,B}の全要素が2等分されていて、かつ、それら2等分されたうちの片方の全要素の値がもう片方の全要素の値より大きいという状態である。
このとき中央値が下記式で求められる。
median = (max(left_part) + min(right_part)) / 2
なぜなら中央値とは、ソート済み配列を2等分したときの断面の要素だから。
len(left_part)、len(right_part)は、下記のように書き表すことができる。
len(left_part) = (i - 1 - 0) + (j - 1 -0) = i + j - 2
len(right_part) = (m - 1 - i) + (n - 1 - j) = m + n - i - j - 2
これをさらに変形する。
len(left_part) = len(right_part)
=> i + j = m + n - i - j
=> j = (m + n)/2 - i
ここの "j = (m + n)/2 - i"という式は、変形元が"len(left_part) = len(right_part)"である。そのため、これは分断した左右の配列長が等しい、つまり全体の長さ(m+n)が偶数であるケースである。今回の問題は、全体の長さが奇数になるケースも存在する。その場合、変形元は"len(left_part) = len(right_part)+1"となり、それを変形すると"j = (m + n + 1)/2 - i"となる。下記URLにも記載されているが、m+nが偶数の場合は、奇数の式で同じ値が算出できてしまうので、常に奇数の場合の式を採用すればよいということになる。
m+nが奇数の場合 j = (m + n + 1)/2 - i
m+nが偶数の場合 j = (m + n)/2 - i = (m + n + 1)/2 - i
中央値を求める際left_partとright_partが満たすべき2条件をまとめると、下記のようになる。
1, i + j = m - i + n - j
n>=mのとき iの範囲はi=0~mで、そのときのjは、j=(m+n+1)/2 - i
2,B[j - 1] <= A[i] かつ A[i -1] <= B[j]
n>=mが前提なのはjを負の値にしたくないからである。与えられた2つの配列の長さがm>=nなら、スワップしてやればよい。
最終的に下記のことが言える。
下記条件を満たすiを0~mの範囲から探索すればよい。
B[j - 1] <= A[i] かつ A[i -1] <= B[j]
ここで、j=(m+n+1)/2 - iである。
具体的な探索方法(模範解答より)
次に探索方法だが、これには二分探索法を用いる。探索範囲である0~mに対して、二分探索法を適用するということである。ここからは結構簡単。
手順1
imin = 0, imax = m とする。これが探索開始位置となる。探索範囲は[imin, imax]と記述する。
手順2
2つの配列をカットする切れ目となるインデックスi,jを式で求める。
i = (imin+imax)/2, j = (m+n+1)/2 -i
手順3
i,jは、すでにlen(left_part)=len(right_part)を満たすものである。もうひとつの条件である、max(left_part) <= min(right_part)を満たすかどうか判定する。
(1) B[j-1] <= A[i] かつ A[i-1] <= B[j] のとき、これは条件を満たすi,jということなので、探索をここで打ち切る。手順4へ。
(2) B[j-1] > A[i]ならば、Aを大きくしなければならないということを意味する。Aをカットすべき位置がiよりも右にあるということなので、ここよりも右を探索範囲にして、再探索する。つまり[i+1, imax]を新たな探索範囲として、手順2へ。
(3) A[i -1] > B[j]ならば、Aを現在よりも小さくする必要があるということを意味する。Aをカットすべき位置がiよりも左にあるということなので、ここよりも左側を探索範囲として、再探索する。つまり、[imin, i-1]を新たな探索範囲として、手順2へ。
手順4
中央値を計算する。
m + nが奇数なら、median = max(A[i-1], B[j-1])
m + nが偶数なら、median = (max(A[i-1], B[j-1]) + min(A[i], B[j])) / 2
計算量について
時間計算量について。与えられた2つの配列の長さがm,nのとき、その小さいほう、すなわちmin(m,n)が探索範囲になる。その範囲を二分探索法によって1ループごとに探索範囲を半分ずつ絞っていくので、計算量はO(log(min(m,n)))になる。1ループ内での処理は定数演算のみなので、無視できる。
空間計算量について。与えられたデータのサイズや探索経過によらず、変数の個数が常に9個であるので、空間計算量はO(1)である。
コード
模範解答が理解できたところで、改めてコードを書いてみた。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
// 小さい方をarray_aとする
auto m = nums1.size();
auto n = nums2.size();
auto array_a = nums1;
auto array_b = nums2;
// 大小が逆ならswap
if(m > n){
n = nums1.size();
m = nums2.size();
array_b = nums1;
array_a = nums2;
}
auto imin = 0;
auto imax = m;
while(1){
auto i = (imin + imax) / 2;
auto j = (m + n + 1) / 2 - i;
if(i < imax && array_a[i] < array_b[j - 1]){
// 求めるiは、いまのiよりも右にある
imin = i + 1;
continue;
}else if(i > imin && array_b[j] < array_a[i - 1]){
// 求めるiは、いまのiよりも左にある
imax = i - 1;
continue;
}else{
int left_max;
if(i == 0){
// Aのleftが空のケース
left_max = array_b[j - 1];
} else if(j == 0){
// Bのleftが空のケース
left_max = array_a[i - 1];
} else{
left_max = std::max(array_a[i - 1], array_b[j - 1]);
}
if((m + n) % 2 != 0){
return left_max;
}
int right_min;
if(i == m){
// Aのrightが空のケース
right_min = array_b[j];
}else if(j == n){
// Bのrightが空のケース
right_min = array_a[i];
} else{
right_min = std::min(array_a[i], array_b[j]);
}
return (left_max + right_min) / 2.0;
}
}
throw std::logic_error("cant resolve");
}
};
解き方が理解できればコードをスラスラ書けるか、といわれると意外とそうでもない。実は理解できていなかったりする部分があったりするし、解説だけだと具体性が足りなかったりする。