強化学習に必要な「Qテーブル」と「離散値で表す関数」をつくるには?
こんにちは!
ぷもんです。
前回、離散値ってなんや?というnoteで
強化学習で必要な離散値とは何か?なぜ必要なのか?
について書きました。
今回は具体的に離散値に変換していきます。
今回やるのはこちらのコードです。
q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))
def bins(clip_min, clip_max, num):
return np.linspace(clip_min, clip_max, num + 1)[1:-1]
def digitize_state(observation):
cart_pos, cart_v, pole_angle, pole_v = observation
digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
return sum([x * (4 ** i) for i, x in enumerate(digitized)])
やっていることは
・Qテーブルという表のようなものを作る
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
の2つです。
ではやって行きます!!
・Qテーブルってなんや?
強化学習の手法のうち、Q学習というものをやっています。
Q学習ではある時間tのある状態sである行動aを取った時どうなるかを
関数で表したQ関数というものを作ります。
Q関数を表で表したものがQテーブルです。
q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))
ではQテーブルを作っています。
Qテーブルという表みたいなものを作っているのはわかったけど
右辺の意味が気になります...。
・np.random.uniform(low, high, size)
np.random.uniform(low, high, size)は
low以上high未満の一様乱数をsize個つくるという意味です。
一様乱数はランダムな数という意味です。
今回の場合は
q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))
−1以上1未満の一様乱数を
(4 ** 4, env.action_space.n)個つくるという意味になります。
....。
sizeの内容が(4 ** 4, env.action_space.n)で意味不明です。
4 ** 4は4の4乗を表しており
前回、離散値ってなんや?で説明したように
・カート位置
・カート速度
・棒の角度
・棒の角速度
の4つの値を4つの領域に分けたことを表す4の4乗です。
env.action_space.nはこのゲームで、有効なactionを表していて
今回やっているCartPoleでは右に移動するか左に移動するかの2択なので
2になります。
ここまでわかると
縦が4の4乗の256、横が2の表が作られているのが
イメージできるのではないでしょうか?
やっと1行目を理解できました...。
ここまでで
・Qテーブルという表のようなものを作る
が終わりました。
ここからは
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
に入ります。
まずは
def bins(clip_min, clip_max, num):
return np.linspace(clip_min, clip_max, num + 1)[1:-1]
です。
・defで関数を宣言!
Pythonではdefを使って
returnの後に書いた値を戻す
関数を宣言することができます。
今回の場合は
np.linspace(clip_min, clip_max, num + 1)[1:-1]という動きをする
bins(clip_min, clip_max, num)という関数を作っているのがわかります。
・np.linspace(始点, 終点, 何分割)[スライス]
np.linspaceを使うと等差数列がつくれます。
1、2、3、4、5、6、7、8、9、10は
公差が1の等差数列です。
np.linspaceに続く値では始点、終点、何分割するかを示すことができて
[]の値ではどれくらいの値を切り取るかを指定できます。
今回の場合
np.linspace(clip_min, clip_max, num + 1)[1:-1]
clip_minを始点、clip_maxを終点とするnum + 1分割した等差数列のうち
1〜−1を切り取るという意味になります。
clipは報酬の値を示すっぽいのですが
よくわからないので進めていくうちに理解できるのを待ちます。
つまり、ここでは
def bins(clip_min, clip_max, num):
return np.linspace(clip_min, clip_max, num + 1)[1:-1]
clip_minを始点、clip_maxを終点とするnum + 1分割した等差数列のうち
1〜−1を切り取りとった値を戻す
bins(clip_min, clip_max, num)という関数を定義しました。
ラストがこちらです!
def digitize_state(observation):
cart_pos, cart_v, pole_angle, pole_v = observation
digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
return sum([x * (4 ** i) for i, x in enumerate(digitized)])
まずはdigitizedというリストをつくっています。
・observationをバラバラに
cart_pos, cart_v, pole_angle, pole_v = observation
observationは観測したデータを表していて
それをcart_pos, cart_v, pole_angle, pole_v
つまりカート位置、カート速度、棒の角度、棒の角速度に分けています。
・digitize関数って何や?
digitize関数ではある値がどの範囲に入るか?を求めることができます。
np.digitize(値, 範囲)を意味していて
np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4))
の場合cart_pos(=カート位置)が
-2.4,から2.4,を4つに分けた範囲のどの範囲に入るかを示します。
(先ほど作ったbins(clip_min, clip_max, num)が使われていますね!)
ちなみに4つの範囲のどれに入るのかは
0〜3の値で返してくれます。
他の3つも同じように働くとすると
digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
は[0,1,1,3]のような0〜3が入るリストの形になることがわかります。
続いてreturnの中を見てみます。
・enumerate()
enumerate関数というものでforと一緒に使うことで
何番目に要素が入っているかと一緒に表示することができます。
list = ['1', '2', '3']
for i, name x enumerate(list)
この場合iが何番目かxがリストの中身を示すので
0,1
1,2
2,3
のように0番目が1というように示されます。
今回の場合はlistが
digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
の部分になります。
これを踏まえて
return sum([x * (4 ** i) for i, x in enumerate(digitized)])
を見ていきます。
・sum([式 変数)])
sum関数は合計する関数なのですが
今回のように好きと変数が入っている場合は変わります。
今回の場合は
式=x * (4 ** i)
変数=for i, x in enumerate(digitized)
のようになっていて
digitizedのリストからi番目のxという要素かという変数を
x * (4 ** i)の式に代入したものの合計を返します。
この値は0〜255になって
4の4乗の256の離散値のどれかになっていることがわかります。
ここまでで
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
ことができ
・Qテーブルという表のようなものを作る
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
の2つができました!
めっちゃむずかった...。
でも初めは絶対無理だと思っていた
複雑なコードも基本のコードの組み合わせで
一つずつ読んでいけば理解できることがわかりました。
めっちゃ根気が必要ですが...。
次はこの関数を使って
強化学習を入れるはずです!!
参考にしたサイトはこちらです。
最後まで読んでいただきありがとうございました。
ぷもんでした!