クイックソートは難しい? : 文科省研修教材のコードには虫がいるよ
「高校教員のためのPython 入門から応用まで」に,関数の再帰呼び出しについて書こうと思って,文科省の情報Iの教員用研修教材に掲載されているクイックソートを見てみた。クイックソートを載せるかどうかを検討するために。
そうしたら,虫がいた。
該当箇所はこれ。
まず,この通り打ち込んで動かしてみた。
ちゃんと並び変わっている。
ではどこに虫?
それを説明する前に,再帰的呼び出しについて説明しておこう。この教材にはまったく説明が書かれていないので。
再帰呼び出し,もしくは再帰的定義とは,関数定義の中で,その関数自身を呼び出すものだ。
といっても意味がわからないだろう。すんなり理解できる生徒がどのくらいいるか,ちょっと見当がつかない。今までの授業で再帰呼び出しは扱ったことがないからだ。といって,まったくゼロではなく,理数科の課題研究でレクチャーしたことはある。
よく使われるのが,階乗の計算だ。整数 n の階乗は,次のように計算される。
n != n(n-1)(n-2)・・・・・2・1
高校1年生の数学の内容だ。これを次のように書くことができる。
n! = n・(n-1)!
すると,nの階乗を計算する関数を fact(n) とすると
fact(n) = n * fact(n-1)
ということになる。
def fact(n):
if n == 1:
return 1
else:
return n * fact(n-1)
fact(n) の関数定義の中で,自分自身の fact() を使っている。
fact(n)を計算するために fact(n-1) を使う。
fact(n-1)を計算するために fact(n-2) を使う。
これを繰り返すのだが,いつかは終わらなくてはならない。それが if n == 1 : だ。1!= 1 なので,ここで終わる。
授業でクイックソートを扱うか。筆者だったら,教科書に載っていなければやらない。再帰呼び出しの概念が簡単ではないと思うからだ。
しかし,載っていたらやらざるを得ない。ならば,まずこの「階乗」の説明をする。これを70%くらいの生徒は理解できると期待したい。
さて,問題のクイックソート。ここで再帰呼び出しを使っている。
if start < i - 1:
quicksort(a, start, m-1)
if end > j + 1:
quicksort(a, m+1, end)
では,どうやって並べ替えをしているのだろうか。テキストには何も説明がないので,プログラムを読むしかない。しかし,そう簡単に読み取れるシロモノではない。そこで,データ量を減らして,途中の状態を print() で表示することにした。データの後ろを切って,
a = [7,22,11,34,17,52,26,13]
としてみた。
あれれ? ダメじゃん。・・・ 虫さん,虫さん,どこかな〜
あらかじめお断りをしておく。ここから先は難しい・ややこしいから,時間のあるときにじっくり取り組んでもらいたい。
プログラムコードを読む限り,おかしなところはなさそうだが,そもそも何をやりたいのかがわからない。「リスト内の一つのデータを軸として,大小2つに分割した後,分割したデータに対して同じ 処理を再度行う」とだけしか書いてない。ここの「同じ処理」は何をしているのか説明がないのだ。
とはいえ,プログラムコードを読むと,次のようにしたいらしいことはわかる。
中央の値をとり,その値より大きいものは左(後ろ)へ,その値より小さい
ものは右(前)にもっていくように入れ替えていく。
入れ替えができたら,前半と後半について同じことをする(再帰呼び出し)
途中経過を表示するために状態を知りたいところにprint文を入れていく。何度かやり直して,次のようになった。
called [7, 22, 11, 34, 17, 52, 26, 13] ・・・ quicksort() を呼び出したとき
start= 0 end= 7 m= 3 a[m]= 34 ・・・ そのときの各値
i,jを探索後 i= 3 j= 7 ・・・ i, j を探索したあとの値
swap後 [7, 22, 11, 13, 17, 52, 26, 34] ・・・ いれかえたあと
m= 7 i= 4 j= 6 start= 0 end= 7
i,jを探索後 i= 5 j= 6
swap後 [7, 22, 11, 13, 17, 26, 52, 34]
m= 7 i= 6 j= 5 start= 0 end= 7
start < i-1 : True // end > j+1 : True ・・・ while ループを抜けたときの値
call quicksort(a,start,m-1) ・・・ start < i-1 : True なのでこれを実行
called [7, 22, 11, 13, 17, 26, 52, 34] ・・・ 再帰呼び出しでの a の状態
start= 0 end= 6 m= 3 a[m]= 13 以下同様
i,jを探索後 i= 1 j= 3
swap後 [7, 13, 11, 22, 17, 26, 52, 34]
m= 1 i= 2 j= 2 start= 0 end= 6
さて,どこがおかしい?
最初は,リスト全体について quicksort() を呼ぶ。そのとき
start= 0 end= 7 m= 3 a[m]= 34 ・・・ そのときの各値
になっている。データは8個ある。インデックスは0スタートなので,m=3 は3番目(普通に数えると4番目)の値 34 だ。i を start の0,j をend の7 にして,
while a[i] < a[m]:
i = i + 1
while a[j] > a[m]:
j = j - 1
で,m 番目の左の値を見ていって m番目の値 a[m] =34 より小さければ i を進めていく。ということは,a[m]以上の値を探していくわけだ。
[7, 22, 11, 34, 17, 52, 26, 13] なので,34より左に a[m]以上の値はないから,i は m になる。
また,m番目の右は,a[m]以下の値を探している。最初の 13 が該当するので,jは7だ。この2つを入れ替える。
i,jを探索後 i= 3 j= 7 ・・・ i, j を探索したあとの値
swap後 [7, 22, 11, 13, 17, 52, 26, 34] ・・・ いれかえたあと
もし,もとのデータが [7, 22, 41, 34, 17, 52, 26, 13] であれば,34 の左に,34より大きい 41 があるから 41 と 13 を入れ替えて
i,jを探索後 i= 2 j= 7
swap後 [7, 22, 13, 34, 17, 52, 26, 41]
となる。i と m が等しくなるかどうか,ここにひとつ鍵がある。
次に,
if i == m:
m = j
elif j == m:
m = i
i = i + 1
j = j - 1
として,while の繰り返しに戻る。入れ替えた後は
m= 7 i= 4 j= 6 start= 0 end= 7
となっているので,while (i < j) の条件を満たして,同じことを繰り返すわけだ。こんどは
i,jを探索後 i= 5 j= 6
swap後 [7, 22, 11, 13, 17, 26, 52, 34]
m= 7 i= 6 j= 5 start= 0 end= 7
となるから,while の繰り返し条件 i < j を満たさず,while を抜けて次に進む。
start < i-1 : True // end > j+1 : True ・・・ while ループを抜けたときの値
なので,
if start < i - 1:
quicksort(a,start,m-1)
で再帰呼び出しとなる。
called [7, 22, 11, 13, 17, 26, 52, 34] ・・・ 再帰呼び出しでの a の状態
start= 0 end= 6 m= 3 a[m]= 13 以下同様
結構ややこしい。画面で追うのは大変なので,プリントアウトして,プログラムコードと途中経過を比べた。(前述の経過)
さあ,おかしいのはどこ?
再帰呼び出しで,
called [7, 22, 11, 13, 17, 26, 52, 34] ・・・ 再帰呼び出しでの a の状態
start= 0 end= 6 m= 3 a[m]= 13 以下同様
となっているところだ。start=0 , end=6 だから,処理対象は
[7, 22, 11, 13, 17, 26, 52]
である。これは左側を再帰呼び出ししたところ。この右に34がある。これを並べ替えていくとどうなるか。
34より大きい 52 が 34 より右に来ることはなくなる。
結果をもう一度見てみよう。
34 が右端にいったままおいてけぼりになっている。
さあ,プログラムコードを見て,おかしなところを探そう。
if i == m:
m = j
これ,なんだ?
その前に,i と m が同じになって,a[i] つまり a[m] =34 とa[j]=13 を入れ替えてしまったので,mをあらためて j にしたんだね。そうでないと a[m] は 13 になるから,基準値が変わってしまう。mを j にしたから m=7 になって,a[m]=34 だから基準値は変わらない。
swap後 [7, 22, 11, 13, 17, 52, 26, 34] ・・・ いれかえたあと
m= 7 i= 4 j= 6 start= 0 end= 7
これ,ほんとにいいのか?
さあ,もう一度,クイックソートのアルゴリズムを考えよう。Web上で「クイックソート」を検索してみるのもいいだろう。
クイックソートのアルゴリズム(今の場合)
(1) 基準の値を決める。ここでは,並んでいる値の中央の値。
偶数個なら,前半の最後の値。
(2) i を左端(start)から始めて,基準値以上のものを探す。
(3) j を右端(end)から始めて,基準値以下のものを探す。
(4) (2)の値と(3) の値を入れ替える。
(5) (2)〜(4) を続けて,i と j が交差したら ( if i >= j )探索と入れ替えは終わる。
break 文のところ。
(6) これで,左の方には基準値以下の値,右の方には以上の値だけが集まる
(7) 左のグループに対して再帰的に(1)〜(6)を行う。
右のグループにも再帰的に(1)〜(6) を行う。
では,このコードはこのようになっているか。なっていない。
i が m と等しくなったために,入れ替えたあと m = j とした。これが間違いだったのだ。そして,再帰呼び出しの時に,quicksort(a,start,m-1) としたこと。このために,(6)が保たれなくなってしまったのだ。
修正しよう。基準値は key = a[m] として変数に代入し,以後変化させない。したがって,m = j とする必要もない。
再帰呼び出しのとき,(7) に分けるのは i と j が交差したときの位置だから m ではなく i または j のところにする。
# 改訂したつもりの版
def quicksort(a, start, end):
m = int((start+end)/2)
i = start
j = end
key = a[m]
while(i < j):
while a[i] < key:
i = i + 1
while a[j] > key:
j = j - 1
if i >= j:
break
temp = a[i]
a[i] = a[j]
a[j] = temp
i = i + 1
j = j - 1
if start < i - 1:
quicksort(a, start, i - 1)
if end > j + 1:
quicksort(a, j + 1, end)
a = [7, 22, 11, 34, 17, 52, 26, 13]
print(" ソート前 ", a)
quicksort(a, 0, len(a)-1)
print(" ソート後 ", a)
結果は
ソート前 [7, 22, 11, 34, 17, 52, 26, 13]
ソート後 [7, 11, 13, 17, 22, 26, 34, 52]
できた。
いや,ほんとにこれで大丈夫?他のもやってみよう。
あれれれ? ダメじゃん。まだ他の虫がいる。
途中経過をだす。
ソート前 [10, 12, 13, 12]
called [10, 12, 13, 12]
1 a[m]= 12
swap 1 3 [10, 12, 13, 12]
i,j= 2 2
called [10, 12]
0 a[m]= 10
i,j= 0 0
ソート後 [10, 12, 13, 12]
i と j が同じ値 2 になっていて,値13を指している。それから前半の[10,12]に対して再帰呼び出しをしているが,後半の[13,12]をやっていない。
if end > j + 1:
quicksort(a, j + 1, end)
なのだが,end > j + 1 が成り立たないので呼び出していないのだ。
じゃあ, if end > j にしてみる? うまくいかない。
あれこれ試して悩んだ揚げ句,やっとわかった。
[10,12] と [13,12] の両方にわけて再帰呼び出しがなされないのは,i , j が同じ値になっているのと,end > j + 1 が成り立たないからなのだが,なぜそうなるかというと,while の中の最後のところで
i = i + 1
j = j - 1
としているからだ。ここで,i と j が同じになり,while を抜ける。すると,再帰呼び出しがうまくいかない。
ということは,while の繰り返し条件 i < j がまずいのだ。これを i <= j とすれば,jの値がひとつ進んで,再帰呼び出しの条件 end > j + 1 が成り立つ。
while(i<j) を while i <= j に直して(カッコはいらないのと,不等号の前後にスペースというのも直した)でき上がり! ・・・ かな?
先ほど,改訂したつもりの版でうまくいかないことを発見したのは, a = [10, 12, 13, 12] を適当に思いついてやってみたわけではない。コンピュータで乱数列を作ってやってみたのだ。
こんな風にする。numpy を import しておく。
a = list(np.random.randint(1, 20, 10))
b = sorted(a)
print(" ソート前 ", a)
quicksort(a, 0, len(a)-1)
print(" ソート後 ", a)
print(" ソート正 ", b)
print(a == b)
これを実行していくと,ほどなく a == b が False となるものが見つかる。
ソート前 [1, 16, 12, 10, 11, 1, 1, 4, 15, 6]
ソート後 [1, 1, 4, 6, 10, 1, 11, 12, 15, 16]
ソート正 [1, 1, 1, 4, 6, 10, 11, 12, 15, 16]
False
while(I<j) を while i <= j に直した最終版では,これをループで10000回やってみたが全部 True だった。
最後に補足。
Python ならではの記法を使うと,
m = int((start+end)/2) は m = (start + end) // 2
と書ける。値の入れ替えなども簡略な記法がある。しかし,情報の授業では「Pythonを教える」のではなく,「Pythonを使ってプログラミングを教える」ことと,大学入試共通テストの疑似言語の仕様を考慮して,Python使いの人には「トロい」と思われる方法でも,初歩的な方法を扱っていくのがよいだろう。商の演算子 // は使うが。(いや,int() の方がいいかな)
Pythonならではの記法などを使うとどうなるか,三重大の奥村先生のページにお手本があるので参考にされたい。