上がるの? 下がるの? 二値分類 AIを使って日経平均株価の予測に挑戦 過学習への対策は始まったばかり編
前回の記事より、今後の課題を振り返る
前回は、RNN LSTMの構造を使用して日経平均株価の予測に対する二値分類を行いました。
AIモデルの予測に関しては、Accuracy(分類精度)が50.29%であったため、未だ山勘レベルです。
しかし、私にとっては、ようやくスタートラインに立てたと感じた結果でした。
さて、前回の結果でAccuracy(分類精度)が低くなった原因として、過学習が挙げられます。
過学習とは、AIモデルが学習データで最適化されすぎた結果、条件が異なる他の評価データに対して学習した性能が十分に発揮されない状態です。
上記は、前回の結果における学習曲線を表しています。
Epoch(AIモデルが学習した回数)が増えるに従い、赤い実線で示された学習時のエラー(TRAINING ERROR)が低下しているのが確認できます。
一方で、Epochの増加に従い、赤い破線で示された評価時のエラー(VALIDATION ERROR)は増加しています。
この状態が過学習を表しています。
一般的に、過学習への対策は以下の通りです。
過学習への対策
学習データ(説明変数)の見直し
AIモデルの見直し
そこで、今回は、学習データの見直しを行うことにしました。
学習データの見直し
これまで使用してきた学習データのフォーマットは、下記の通りです。
これまでの学習データのフォーマット
1行7列のベクトルデータ
左から終値、始値、高値、安値、5日移動平均、25日移動平均、75日移動平均
RNNに使用する場合は、1行7列で表される1日分のデータを2日分横に並べて1行14列としている
今回使用する学習データのフォーマットは、下記の通りです。
今回の学習データのフォーマット
1日分のデータを左から終値、始値、高値、安値とする
5日分のデータを時系列に従い左から右に並べて1行20列のベクトルデータとする
RNN LSTMの過学習への対策のために用意した学習データの基データを下記に示します。
U列のラベルは、翌営業日の終値が当日の終値以上であれば1, そうでなければ0としています。
例えば、セルU2はIF(Q2<=Q3, 1, 0)としています。
今回は学習データのフォーマットを変更しましたので、これに伴いLSTMのパラメーターも修正する必要があります。
AIモデルに使用したRNN LSTMの構造図
学習データのフォーマットに合わせてパラメーターを修正したLSTMの構造図を下記に示します。
修正箇所は、下記の2カ所です。
AIモデルの修正箇所
InputのSizeを20に修正
ReshapeのOutShapeを5,4に修正
Reshapeで1行20列のベクトルデータを5行4列(5×(終値、始値、高値、安値))のベクトルデータに変換しています。
そして、RecurrentInputおよびRecurrentOutputで挟まれた層では、分割された1行4列(終値、始値、高値、安値)のベクトルデータに対して、処理を合わせて5回行います。
AIモデルの学習および評価を実行
修正したAIモデルに今回作成したRNN用の学習データを学習させた際の学習曲線を下記に示します。
前回の結果と同じく、AIモデルの学習が進むにつれてVALIDATION ERRORが増加する傾向となっていますが、途中でリセットされたような状況も確認できます。
学習の途中で何が起きたのかは分かっておりません。
続いて、今回の評価結果に対する混同行列を示します。
比較のため、下記に前回の評価結果に対する混同行列を示します。
僅かではありますが、Accuracy(分類精度)が50.29%から56.73%に改善していることが確認できます。
今後の課題
引き続き、過学習への対策を検討していきます。
過学習への対策
学習データ(説明変数)の見直し
AIモデルの見直し
学習データの見直しについては、移動平均やボリンジャーバンド、MACD, 等のテクニカル分析で使用される指標の追加を検討します。
AIモデルの見直しについては、中間層の多層化やCNN(Convolutional Neural Network)の使用も検討していきたいと考えています。
ただし、LSTMの構造に切り替えてからはAIモデルの学習時間が数時間に及ぶため、効率の良いやり方も検討したいと思います。
AIモデルのデータについて
今回作成したAIモデルのデータは、Googleドライブにて共有しています。
URL: https://drive.google.com/drive/folders/1eYkB4ob_VThhObaH3WqvalDDjem-UqUr?usp=drive_link
N225_LSTM_Affine100N_5Days.sdcproj
Neural Network Console用のプロジェクトファイル