見出し画像

未来永劫 回帰による推論をし続けるためのPythonプログラム Ver. 3.0


今回は、回帰による推論をし続けるためのPythonプログラムを紹介します

前回の記事では、回帰により日経平均株価を未来永劫予測するAIモデルの作成について紹介しました。

詳細は、下記の記事を参照ください

実際にAIモデルによる回帰を繰り返すには、AIモデルが出力した推論の結果(予測値)をAIモデルに入力することでAIモデルに次の予測値を出力させ、さらにその予測値をAIモデルに入力し、と行っていく必要があります。

もちろん、手作業でもできるのですが、あまりに面倒な作業となるため、これを行うPythonプログラムを作成しました。

Pythonプログラムのソースコード

回帰による推論をし続けるためのPythonプログラムのソースコードは下記となります。

'''
本プログラムの主な機能は以下の2つです
  1. 回帰(Regression)向けAIモデルの推論を繰り返し実行します
  2. AIモデルの回帰学習向けに株価データに基づいた学習/評価データを作成します

事前にNeural Network Libraries(NNabla)のインストールが必要です
'''
from nnabla.utils import nnp_graph
import sys
import numpy as np
import pandas as pd
import pandas_datareader.data as web
import datetime
import yfinance as yf
import argparse
import talib as ta


# Remove exception date from stock data
def FuncRemoveExceptionDate(Data):
    TmpData = Data.copy()

    # Exception date list
    ExceptionDateList = [
        '1997-07-20', # 海の日
        '1999-07-20',
        '1999-09-15',
        '1999-09-23',
        '1999-10-11',
        '1999-11-03',
        '1999-11-23',
        '1999-12-23',
        '1999-12-31',
        '2000-01-03',
        '2000-01-10',
        '2000-02-11',
        '2000-03-20',
        '2000-05-03',
        '2000-05-04',
        '2000-05-05',
        '2000-07-20',
        '2000-09-15',
        '2000-10-09',
        '2000-11-03',
        '2000-11-23',
        '2001-01-01',
        '2001-01-02',
        '2001-01-03',
        '2001-01-08',
        '2001-02-12',
        '2001-03-20',
        '2001-04-30',
        '2001-05-03',
        '2001-05-04',
        '2001-07-20',
        '2001-09-24',
        '2001-10-08',
        '2001-11-23',
        '2001-12-24',
        '2001-12-31',
        '2002-01-01',
        '2002-01-02',
        '2002-01-03',
        '2002-01-14',
        '2002-02-11',
        '2002-03-21',
        '2002-04-29',
        '2002-05-03',
        '2002-05-06',
        '2002-09-16',
        '2002-09-23',
        '2002-10-14',
        '2002-11-04',
        '2002-12-23',
        '2002-12-31',
        '2003-01-01',
        '2003-01-02',
        '2003-01-03',
        '2003-01-13',
        '2003-02-11',
        '2003-03-21',
        '2003-04-29',
        '2003-05-05',
        '2003-07-21',
        '2003-09-15',
        '2003-09-23',
        '2003-10-13',
        '2003-11-03',
        '2003-11-24',
        '2003-12-23',
        '2003-12-31',
        '2004-01-01',
        '2004-01-02',
        '2004-01-12',
        '2004-02-11',
        '2004-04-29',
        '2004-05-03',
        '2004-05-04',
        '2004-05-05',
        '2004-07-19',
        '2004-09-20',
        '2004-09-23',
        '2004-10-11',
        '2004-11-03',
        '2004-11-23',
        '2004-12-23',
        '2004-12-31',
        '2005-01-03',
        '2005-01-10',
        '2005-02-11',
        '2005-03-21',
        '2005-04-29',
        '2005-05-03',
        '2005-05-04',
        '2005-05-05',
        '2005-07-18',
        '2005-09-19',
        '2005-09-23',
        '2005-10-10',
        '2005-11-03',
        '2005-11-23',
        '2005-12-23',
        '2006-01-02',
        '2006-01-03',
        '2006-01-09',
        '2006-03-21',
        '2006-05-03',
        '2006-05-04',
        '2006-05-05',
        '2006-07-17',
        '2006-09-18',
        '2006-10-09',
        '2006-11-03',
        '2006-11-23',
        '2009-09-21',
        '2017-07-17',
        '2017-08-11',
        '2017-09-18',
        '2017-10-09',
        '2017-11-03',
        '2017-11-23',
        '2018-01-01',
        '2018-01-02',
        '2018-01-03',
        '2018-01-08',
        '2018-02-12',
        '2018-03-21',
        '2018-04-30',
        '2018-05-03',
        '2018-05-04',
        '2018-07-16', # 海の日
        '2018-09-17',
        '2018-09-24',
        '2018-10-08',
        '2018-11-23',
        '2018-12-24',
        '2018-12-31',
        '2020-10-01' # 東証システム障害により全銘柄の売買を終日停止
    ]

    for TmpDay in ExceptionDateList:
        TmpData = TmpData[TmpData.index != TmpDay]

    return TmpData


# Stock data download from Stooq
def FuncDLStockDataStooq(Symbol, StartDate, EndDate):
    # データ(行単位)削除フラグ
    RemoveDateFlag = False

    # シンボル変換
    if 'N225' == Symbol:
        TmpSymbol = '^NKX'
        RemoveDateFlag = True
    elif 'TOPIX' == Symbol: TmpSymbol = '^TPX'
    elif 'DOW' == Symbol: TmpSymbol = '^DJI'
    elif 'SP500' == Symbol: TmpSymbol = '^SPX'
    elif 'NAS' == Symbol: TmpSymbol = '^NDQ'
    #elif 'FTSE' == Symbol: TmpSymbol = '^UK100' # NG
    #elif 'USD' == Symbol: TmpSymbol = 'USDJPY' # NG
    #elif 'EUR' == Symbol: TmpSymbol = 'EURJPY' # NG
    else:
        TmpSymbol = Symbol[0] + '.JP'
        RemoveDateFlag = True

    # Stooqから株価データをダウンロード
    Data = web.DataReader(TmpSymbol, 'stooq', StartDate, EndDate)

    # インデックス(日付け)を昇順に並べ替え
    Data = Data.sort_index()
    
    # Volumeの列を削除
    if 'Volume' in Data.columns.values: Data = Data.drop('Volume', axis = 1)

    # 例外の日付けの行を削除
    if RemoveDateFlag: Data = FuncRemoveExceptionDate(Data)

    # カラム(列)名を変更
    if not ('N225' == Symbol or type(Symbol) == list):
        Data = Data.rename(
            columns = {
                'Open': Symbol,
                'High': Symbol,
                'Low': Symbol,
                'Close': Symbol
            })

    return Data


# Stock data download from Yahoo
def FuncDLStockDataYahoo(Symbol, StartDate, EndDate):
    # データ(行単位)削除フラグ
    RemoveDateFlag = False

    # シンボル変換
    if 'N225' == Symbol:
        TmpSymbol = '^N225'
        RemoveDateFlag = True
    #elif 'TOPIX' == Symbol: TmpSymbol = '998405' # NG
    elif 'DOW' == Symbol: TmpSymbol = '^DJI'
    elif 'SP500' == Symbol: TmpSymbol = '^GSPC'
    elif 'NAS' == Symbol: TmpSymbol = '^IXIC'
    elif 'FTSE' == Symbol: TmpSymbol = '^FTSE'
    elif 'HSI' == Symbol: TmpSymbol = '^HSI' # 香港ハンセン指数
    elif 'SHA' == Symbol: TmpSymbol = '000001.SS' # 上海総合指数
    elif 'VIX' == Symbol: TmpSymbol = '^VIX'
    elif 'USD' == Symbol: TmpSymbol = 'USDJPY=X'
    elif 'EUR' == Symbol: TmpSymbol = 'EURJPY=X'
    elif 'CNY' == Symbol: TmpSymbol = 'CNYJPY=X' # 元(中国)
    elif 'IRX' == Symbol: TmpSymbol = '^IRX' # 米13週国債
    elif 'TNX' == Symbol: TmpSymbol = '^TNX' # 米10年国債
    else:
        TmpSymbol = Symbol[0] + '.T'
        RemoveDateFlag = True

    # Yahooは最終日が1日前にズレるため、補正が必要
    EndDate4Yahoo = (pd.to_datetime(EndDate) + datetime.timedelta(days = 1)).strftime('%Y-%m-%d')

    # Yahooから株価データをダウンロード
    Data = yf.download(TmpSymbol, StartDate, EndDate4Yahoo)

    # Adj Closeの列を削除
    Data = Data.drop('Adj Close', axis = 1)

    # Volumeの列を削除
    if 'Volume' in Data.columns.values: Data = Data.drop('Volume', axis = 1)

    # 例外の日付けの行を削除
    if RemoveDateFlag: Data = FuncRemoveExceptionDate(Data)

    # カラム(列)名を変更
    if not ('N225' == Symbol or type(Symbol) == list):
        Data = Data.rename(
            columns = {
                'Open': Symbol,
                'High': Symbol,
                'Low': Symbol,
                'Close': Symbol
            })

    return Data


# Stock data download
def FuncDLStockData(Symbol, StartDate, EndDate):
    # Stooqから株価データをダウンロード
    DataStooq = FuncDLStockDataStooq(Symbol, StartDate, EndDate)

    # Yahooから株価データをダウンロード
    DataYahoo = FuncDLStockDataYahoo(Symbol, StartDate, EndDate)

    # Yahooの株価データを基準に、Stooqの株価データを合成
    Data = pd.concat([DataYahoo, DataStooq])

    # 重複したインデックス(日付け)に対して最初のデータを残す
    Data = Data[~Data.index.duplicated(keep = 'first')]

    # インデックス(日付け)を昇順に並べ替え
    Data = Data.sort_index()

    return Data


# SMA
def FuncSMA(ColName, Data):
    RefCol = 'Close'

    WinShort: int = 5
    WinMiddle: int = 25
    WinLong: int = 75

    SMA1 = Data[RefCol].rolling(window = WinShort).mean().to_frame(ColName)
    SMA2 = Data[RefCol].rolling(window = WinMiddle).mean().to_frame(ColName)
    SMA3 = Data[RefCol].rolling(window = WinLong).mean().to_frame(ColName)

    return pd.concat([Data, SMA1, SMA2, SMA3], axis = 1)


# Bollinger Band
def FuncBB(ColName, Data):
    RefCol = 'Close'

    WinBB: int = 20

    SMABB = Data[RefCol].rolling(window = WinBB).mean()
    StdDevBB = Data[RefCol].rolling(window = WinBB).std(ddof = 0)

    BBM2Sig = (SMABB - 2 * StdDevBB).to_frame(ColName)
    BBM1Sig = (SMABB - StdDevBB).to_frame(ColName)
    BBSMA = SMABB.to_frame(ColName)
    BBP1Sig = (SMABB + StdDevBB).to_frame(ColName)
    BBP2Sig = (SMABB + 2 * StdDevBB).to_frame(ColName)

    return pd.concat([Data, BBM2Sig, BBM1Sig, BBSMA, BBP1Sig, BBP2Sig], axis = 1)


# MACD
def FuncMACD(ColName, Data):
    RefCol = 'Close'

    SpanShort: int = 12
    SpanLong: int = 26
    WinSignal: int = 9

    EMAShort = Data[RefCol].ewm(span = SpanShort).mean()
    EMALong = Data[RefCol].ewm(span = SpanLong).mean()

    # MACD
    MACD = (EMAShort - EMALong).to_frame(ColName)

    # MACD Signal
    MACDSig = MACD[ColName].rolling(window = WinSignal).mean().to_frame(ColName)

    return pd.concat([Data, MACD, MACDSig], axis = 1)


# 一目均衡表
def FuncICHIMOKU(ColName, Data, NoLagFlag = False):
    HighCol = 'High'
    LowCol = 'Low'
    CloseCol = 'Close'

    WinBase: int = 26
    WinConv: int = 9
    Span1Shift: int = 25 # 26?
    WinSpan2: int = 52
    Span2Shift: int = 25 # 26?
    DelayShift: int = -25 # -26?

    # 基準線
    MaxBaseLine = Data[HighCol].rolling(WinBase).max()
    MinBaseLine = Data[LowCol].rolling(WinBase).min()

    BaseLineData = ((MaxBaseLine + MinBaseLine) / 2).to_frame(ColName)

    # 転換線
    MaxConvLine = Data[HighCol].rolling(WinConv).max()
    MinConvLine = Data[LowCol].rolling(WinConv).min()

    ConvLineData = ((MaxConvLine + MinConvLine) / 2).to_frame(ColName)

    # 先行スパン1
    Span1Data = ((BaseLineData[ColName] + ConvLineData[ColName]) / 2).shift(Span1Shift).to_frame(ColName)

    # 先行スパン2
    MaxSpan2Line = Data[HighCol].rolling(WinSpan2).max()
    MinSpan2Line = Data[LowCol].rolling(WinSpan2).min()

    Span2Data = ((MaxSpan2Line + MinSpan2Line) / 2).shift(Span2Shift).to_frame(ColName)

    # 遅行スパン
    LagginSpanData = Data[CloseCol].shift(DelayShift).to_frame(ColName)

    # 遅行スパンのNaNに最後の終値をコピー
    LastCloseRow: int =  len(LagginSpanData) + DelayShift - 1
    for i in range(len(LagginSpanData) + DelayShift, len(LagginSpanData)):
        LagginSpanData.iat[i, 0] = LagginSpanData.iat[LastCloseRow, 0]

    if NoLagFlag:
        # 遅行スパンなし
        return pd.concat([Data, BaseLineData, ConvLineData, Span1Data, Span2Data], axis = 1)
    else:
        # 遅行スパンあり
        return pd.concat([Data, BaseLineData, ConvLineData, Span1Data, Span2Data, LagginSpanData], axis = 1)


# 一目均衡表 遅行スパンなし
def FuncICHIMOKUNoLag(ColName, Data):
    return FuncICHIMOKU(ColName, Data, True)


# RSI
def FuncRSI(ColName, Data):
    RefCol = 'Close'

    SpanRSI: int = 14
    WinSignal: int = 9

    # RSI
    RSI = ta.RSI(Data[RefCol], timeperiod = SpanRSI).to_frame(ColName)

    # RSI Signal
    RSISig = RSI[ColName].rolling(window = WinSignal).mean().to_frame(ColName)

    return pd.concat([Data, RSI, RSISig], axis = 1)


# Stochastics
def FuncSTOCH(ColName, Data):
    HighCol = 'High'
    LowCol = 'Low'
    CloseCol = 'Close'

    WinK: int = 9
    WinD: int = 3
    WinSlowD: int = 3

    # %K
    MaxData = Data[HighCol].rolling(window = WinK).max()
    MinData = Data[LowCol].rolling(window = WinK).min()
    KData = 100 * (Data[CloseCol] - MinData) / (MaxData - MinData)
    KData = KData.to_frame(ColName)

    # %D
    DData = KData[ColName].rolling(window = WinD).mean().to_frame(ColName)

    # Slow%D
    SlowDData = DData[ColName].rolling(window = WinSlowD).mean().to_frame(ColName)

    return pd.concat([Data, KData, DData, SlowDData], axis = 1)


# 説明変数およびRNN向けにデータを横に並べる関数
def DataCopyShift(CopyNum, Data):
    # 出力用のDataFrame変数
    OutData = Data.copy()

    for i in range(CopyNum - 1):
        # Dataをコピー
        TmpData = Data.copy()

        # TmpDataを上方向にi + 1行シフト(i = 0, 1, ...)
        TmpData = TmpData.shift(-(i + 1))

        # OutDataの右側に結合
        OutData = pd.concat([OutData, TmpData], axis = 1)

    # NaNを含む行を削除
    OutData = OutData.dropna()

    return OutData


# 説明変数を作成
# 出力はDataFrame形式
def MakeExpData(EnList, FcList, Day, RNN, Data):
    # 制御定数の定義
    I_Name: int = 0
    I_Func: int = 1
    I_ENum: int = 2

    # 出力用のDataFrame変数
    OutData = Data.copy()

    for EID in EnList:
        # 関数を実行
        for Func in FcList:
            if EID == Func[I_Name]:
                OutData = Func[I_Func](Func[I_Name], OutData)

    # 説明変数向けに指定分だけ各データを横に並べる処理
    if 1 < Day: OutData = DataCopyShift(Day, OutData)

    # RNN向けに指定分だけ各データを横に並べる処理
    if 1 < RNN: OutData = DataCopyShift(RNN, OutData)

    return OutData


# 説明変数と目的変数から次の説明変数を作成
# 出力はDataFrame形式
def MakeExpVariable(ExpVar, ObjVar):
    # ndarrayをDataFrameに変換 + 転置
    TmpObj = pd.DataFrame(ObjVar).T

    # ExpVarの右隣にObjVarを結合
    Data = pd.concat([ExpVar, TmpObj], axis = 1)

    # DataFrameを右方向に4つシフト
    Data = Data.shift(-4, axis = 1)

    # NaNを含む列を削除
    Data = Data.dropna(how = 'all', axis = 1)

    return Data


def main():
    # コマンドライン引数の処理
    parser = argparse.ArgumentParser(description = 'Execute inference using trained AI models by Python API')

    # 学習/評価データ作成用オプション
    parser.add_argument('--s', required = False, default = '1986-01-01', help = 'Start date(Default: 1986-01-01)')
    parser.add_argument('--e', required = False, default = '2024-01-31', help = 'End date(Default: 2024-01-31)')
    parser.add_argument('--vnum', required = False, nargs = 1, type = int, default = 250, help = 'Number of rows for validation data(Default: 250)')
    parser.add_argument('--t', required = False, default = 'training.csv', help = 'Training data file name(Default: training.csv)')
    parser.add_argument('--v', required = False, default = 'validation.csv', help = 'Validation data file name(Default: validation.csv)')
    parser.add_argument('--debug', required = False, action = 'store_true', help = 'Debug mode')

    # 推論実行用オプション
    parser.add_argument('--nnp', required = False, help = 'NNP file')
    parser.add_argument('--rep', required = False, type = int, default = 1, help = 'Repeat count(Default: 1)')

    # 共通オプション
    parser.add_argument('--code', required = False, nargs = 1, help = 'Stock code number')
    parser.add_argument('--day', required = False, type = int, default = 1, help = 'Prameters for explanatory variables(Default: 1)')
    parser.add_argument('--elm', required = False, nargs = '+', help = 'Element to add to training data(Option: SMA BB MACD RSI STOCH)')
    parser.add_argument('--rnn', required = False, type = int, default = 1, help = 'Repeat parameter for RNN(Default: 1)')

    args = parser.parse_args()


    # print文に対する出力制御
    pd.set_option('display.max_rows', 50)
    pd.set_option('display.max_columns', 1000)
    pd.set_option("display.max_colwidth", 1000)
    pd.set_option('display.width', 10000)


    # 学習/評価データ向け期間 or 推論のきっかけとなる日付けの設定
    # 推論を行う場合、重要なのは--eオプションで指定した日付
    DayStart = pd.to_datetime(args.s).strftime('%Y-%m-%d')
    DayEnd = pd.to_datetime(args.e).strftime('%Y-%m-%d')


    # 指定期間の表示
    print('Start day:', DayStart, file = sys.stderr)
    print('End day:', DayEnd, file = sys.stderr)


    # 有効データの開始日を補正(-nnpオプション指定時のみ)
    # 目的: SMA算出に必要な最長75日分の有効な株価データを確保するため
    # 補正の基準は--eオプションで指定された年月日で、補正する日数は200日
    TmpDayStart = (pd.to_datetime(DayEnd) - datetime.timedelta(days = 200)).strftime('%Y-%m-%d')
    if not args.nnp: TmpDayStart = DayStart


    # 有効データの最終日を補正
    # 目的: --eオプションで指定した日の説明変数に対して目的変数(翌営業日の株価)を確保するため
    # 補正の基準は--eオプションで指定された年月日で、補正する日数はday + rnn + 20日
    TmpDayEnd = (pd.to_datetime(DayEnd) + datetime.timedelta(days = args.day + args.rnn + 20)).strftime('%Y-%m-%d')


    # 株価データを取得
    if not args.code:
        # 日経平均株価をダウンロード
        StockData = FuncDLStockData('N225', TmpDayStart, TmpDayEnd)
    else:
        # --codeオプションで指定された銘柄の株価をダウンロード
        StockData = FuncDLStockData(args.code, TmpDayStart, TmpDayEnd)


    # 出力用の変数を用意
    OutData = StockData.copy()


    # 各テクニカル分析指標のイネーブル設定
    EnableList = []

    if args.elm: EnableList = args.elm
    print('Training data element:', ', '.join(EnableList), file = sys.stderr)


    # 関数リスト
    # 名前, 関数名, 列数(要素数)
    FuncList = [
        # SMA(Short, Middle, Long)
        ['SMA', FuncSMA, 3],

        # Bollinger Band(-2Sig, -Sig, SMA, +Sig, +2Sig)
        ['BB', FuncBB, 5],

        # MACD(MACD, MACDSignal)
        ['MACD', FuncMACD, 2],

        # 一目均衡表(基準線, 転換線, 先行スパン1, 先行スパン2, 遅行スパン)
        #['ICHIMOKU', FuncICHIMOKU, 5],

        # 一目均衡表(基準線, 転換線, 先行スパン1, 先行スパン2)
        #['ICHIMOKUNOLAG', FuncICHIMOKUNoLag, 4],

        # RSI(RSI)
        ['RSI', FuncRSI, 2],

        # Stochastics(K, D, SlowD)
        ['STOCH', FuncSTOCH, 3]
    ]


    # 説明変数向けに指定分だけ各データを横に並べる数
    DayNum = int(args.day)
    if 1 < DayNum: print('Day params:', DayNum, file = sys.stderr)

    # RNN向けに指定分だけ各データを横に並べる数
    RNNNum = int(args.rnn)
    if 1 < RNNNum: print('RNN params:', RNNNum, file = sys.stderr)

    # 学習データに要素を追加
    OutData = MakeExpData(EnableList, FuncList, DayNum, RNNNum, OutData)


    # Neural Network Console用のヘッダを設定
    # 説明変数用のヘッダを用意
    NNCHeader = 'x__' + pd.Series(range(0, len(OutData.columns)), dtype = 'str')


    # 回帰向け目的変数を作成
    if not args.nnp:
        # 列名(Open, High, Low, Close)の列のみ抽出
        TmpData = OutData.loc[:, ['Open', 'High', 'Low', 'Close']]

        # 列名(Open, High, Low, Close)が重複している列の最右の列を抽出
        TmpData = TmpData.loc[:, ~TmpData.columns.duplicated(keep = 'last')]

        # TmpDataを上方向に1行シフト
        TmpData = TmpData.shift(-1)

        # OutDataの右側に結合
        OutData = pd.concat([OutData, TmpData], axis = 1)

        # NaNを含む行を削除
        OutData = OutData.dropna()


    # 有効データの最終日をピッタリ合わせる処理
    OutData = OutData[:DayEnd]
    if 2 < (args.day + args.rnn): OutData = OutData.iloc[:-(args.day + args.rnn - 2)]


    # 回帰向け目的変数用のヘッダを追加
    if not args.nnp:
        # 回帰向け目的変数用のヘッダを用意
        NNCHeaderObj = 'y__' + pd.Series(range(0, 4), dtype = 'str')
        NNCHeader = pd.concat([NNCHeader, NNCHeaderObj])


    # 学習データのヘッダを上書き
    if not args.debug: OutData.columns = [NNCHeader]


    # データの分割処理
    if not args.nnp:
        if len(OutData) < args.vnum:
            print('\nError!!')
            print('Too few rows of data:', len(OutData))
            print('Please adjust with --vnum option or --s and --e option')
        else:
            # 学習データをトレーニングデータとバリデーションデータに分割
            print('Number of rows for validation data:', args.vnum)

            TrainingData = OutData[: -args.vnum]

            print('Training data period:', TrainingData.index[0], '-', TrainingData.index[-1])

            ValidationData = OutData[len(OutData) - args.vnum :]

            print('Validation data period:', ValidationData.index[0], '-', ValidationData.index[-1])


    # Debug mode
    if args.debug:
        if not args.nnp and args.vnum < len(OutData):
            print('\nTraining data')
            print(TrainingData)

            print('\nValidation data')
            print(ValidationData)
        elif args.nnp:
            print('\nPrediction data')
            print(OutData)


    # 各種データをcsvファイルに出力
    # ただし、Debug mode時はファイル出力しない
    if not args.nnp and not args.debug and args.vnum < len(OutData):
        # トレーニングデータをcsvファイルに出力
        print('Training data csv file:', args.t)
        TrainingData.to_csv(args.t, index = False, float_format = '%.4f')

        # バリデーションデータをcsvファイルに出力
        print('Validation data csv file:', args.v)
        ValidationData.to_csv(args.v, index = False, float_format = '%.4f')


    if not args.nnp: print('\nDone.')


    # 推論実行ルーチン
    if (not args.debug and args.nnp):
        # 標準エラー出力
        print('NNP file:', args.nnp, file = sys.stderr)
        print('Repeat count:', args.rep, file = sys.stderr)

        # 推論用データの基データを用意
        TmpStockData = StockData.copy()
        TmpStockData = TmpStockData[:DayEnd]

        # 学習済みニューラルネットワークの読み込み
        nnp = nnp_graph.NnpLoader(args.nnp)

        # 推論用ニューラルネットワークの取得
        graph = nnp.get_network('MainRuntime', batch_size = 1)

        # 入力変数xの取得
        x = list(graph.inputs.values())[0]

        # 出力変数yの取得
        y = list(graph.outputs.values())[0]

        # 推論用の入力データを代入
        PredictData = OutData.tail(1)

        # 推論処理
        for i in range(args.rep):
            # 入力変数xに値を代入
            x.d = PredictData

            # 推論実行
            y.forward()

            # 推論結果をカッコ無しで表示
            print(*y.d[0], sep = " ")

            # DatetimeIndex向けの変数
            TmpDateIndex = (pd.to_datetime(DayEnd) + datetime.timedelta(days = i + 1)).strftime('%Y-%m-%d')

            # ndarrayをDataFrameに変換
            TmpData = pd.DataFrame(y.d, index = [TmpDateIndex], columns = ['Open', 'High', 'Low', 'Close'])

            # 推論用データの基データを更新
            TmpStockData = pd.concat([TmpStockData, TmpData])

            # 推論用データに要素を追加
            PredictData = MakeExpData(EnableList, FuncList, DayNum, RNNNum, TmpStockData)

            # 推論用の入力データを更新
            PredictData = PredictData.tail(1)


if __name__ == '__main__':
    main()
  • お断り

    • 上記のPythonプログラムは、私のPython環境(Ver. 3.9.7)でのみ動作確認を行っています

    • 上記のPythonプログラムに対しては、私が思いつく限りのデバッグを行いましたが、バグが残っている可能性があります

    • バグは、発見次第修正します

    • 修正情報の発信は、本記事を更新することで行います

上記のPythonプログラムをファイルに保存する場合は、ファイル名の拡張子を"py"とします。

本記事では、上記のPythonプログラムをfilename.pyに保存したものとして以後の説明を行います。

Pythonプログラムのオプション一覧(Ver. 3.0)

Ver. 3.0で使用可能なオプションは、下記の通りです。

  • オプションの詳細

    • --code

      • 銘柄を指定(デフォルトは指定なし、つまり、日経平均株価が選択される)

      • ファーストリテイリング(9983)を指定する場合は"--code 9983"

    • --day

      • 説明変数に使用する株価(始値、高値、安値、終値)の日数を指定(デフォルトは1日)

      • 説明変数に5日分の株価を使用する場合は"--day 5"

    • --elm

      • 学習データおよび評価データに含ませる要素を指定(デフォルトは指定なし)

      • 指定方法は"--elm SMA"とか"--elm BB MACD"とか

        • 移動平均(5, 25, 75日): SMA

        • ボリンジャーバンド: BB

        • MACD: MACD

        • 一目均衡表: ICHIMOKU

        • RSI: RSI

        • ストキャスティクス: STOCH

      • 指定する順番に従って説明変数内での並びも変わる

    • --rnn

      • Neural Network ConsoleのRNN用にn日分のデータを1行ベクトルとして作成(デフォルトは1日)

      • 指定方法は"--rnn 5"とか

    • --s

      • 株価データの開始年月日を指定(デフォルトは1986-01-01)

      • 指定方法は"--s 2001-8-10"とか"--s 2010-03-05"とか

    • --e

      • 株価データの終了年月日を指定(デフォルトは2024-01-31)

      • 指定方法は--sオプションと同じ

    • --vnum

      • 評価データの日数を指定(デフォルトは250日)

      • 指定方法は"--vnum 220"とか

    • --t

      • 学習データを出力するファイル名を指定(デフォルトはtraining.csv)

      • 指定方法は"--t train.csv"とか

    • --v

      • 評価データを出力するファイル名を指定(デフォルトはvalidation.csv)

      • 指定方法は"--v val.csv"とか

    • --debug

      • デバッグモードを有効化(デフォルトは無効)

      • 指定方法は"--debug"

Pythonプログラムの実行方法(Ver. 3.0)

Ver. 3.0では、主に2つの機能をサポートしています。

  • Ver. 3.0でのサポート機能

    • 回帰向け学習/評価データの出力

    • 学習済みAIモデルに対する回帰による推論の実行

回帰向け学習/評価データの出力方法

対象期間を2000年1月1日から2020年12月31日までとし、学習1回分の説明変数(1行ベクトル)に日経平均株価の20日分を割り当てる場合は、次のように実行します。

> python filename.py --s 2000-1-1 --e 2020-12-31 --day 20
Start day: 2000-01-01
End day: 2020-12-31
[*********************100%%**********************]  1 of 1 completed
Training data element:
Day params: 20
Number of rows for validation data: 250
Training data period: 2000-01-04 00:00:00 - 2019-11-21 00:00:00
Validation data period: 2019-11-22 00:00:00 - 2020-12-03 00:00:00
Training data csv file: training.csv
Validation data csv file: validation.csv

Done.

学習データはtraining.csvに、評価データはvalidation.csvに保存されます。

学習済みAIモデルに対する回帰による推論の実行方法

学習済みAIモデルに対して、回帰による推論を2021年1月1日から10営業日分だけ行う場合は、以下のように実行します。

> python filename.py --e 2020-12-31 --day 20 --nnp results.nnp --rep 10
Start day: 1986-01-01
End day: 2020-12-31
[*********************100%%**********************]  1 of 1 completed
Training data element:
Day params: 20
NNP file: .\LearnedModel\N225_4Affine_Regression_20days.nnp
Repeat count: 10
27468.389 27589.0 27329.082 27430.32
27489.898 27609.23 27348.06 27449.324
27513.941 27637.643 27378.771 27481.953
27480.236 27606.291 27340.605 27444.637
27481.992 27602.352 27343.955 27446.834
27457.81 27583.81 27322.256 27425.377
27452.602 27572.947 27316.732 27416.625
27454.674 27578.266 27311.387 27414.586
27481.186 27606.787 27333.621 27438.18
27485.945 27615.44 27335.28 27444.432

上記において、results.nnpはNeural Network Consoleが生成した学習済みAIモデルのファイルです。

注意が必要な点として、--eオプションで指定する日付けがあります。

回帰による推論を2021年1月1日から始める場合は、その前営業日の説明変数を学習済みAIモデルに与える必要があります。

このため、前営業日の日付を--eオプションで指定して下さい。

ちなみに、 --sオプションで日付けを指定する必要はありません (デフォルトのままで問題ありません)。

学習済みAIモデルが推論した結果は、日経平均株価(予測値)が始値、高値、安値、終値の順に日単位で出力されます。


2024年6月20日 Ver. 2.0 説明変数が複数日に対応しました(Ver. 3.0で--csvオプションは廃止されました)

--csvオプションで指定する推論のきっかけとなる始めの一回分の日経平均株価の情報(説明変数)が複数日に対応しました。

例えば、5日分の日経平均株価の情報を与える場合、--csvオプションで指定するファイルのデータフォーマットは以下のようになります。

推論 きっかけ データ フォーマット 複数日 対応
推論のきっかけとなるデータのフォーマット(複数日対応)

1日分の日経平均株価の情報が始値、高値、安値、終値であるため、5日分で20個のデータとなります。

左から順に、1日目の始値、高値、安値、終値、2日目、…と続きます。


2024年6月26日 Ver. 3.0 説明変数がテクニカル指標に対応し、かつ、学習/評価データの出力機能をサポートしました

今回のバージョンアップは大変更となりました。

  • 変更点一覧

    • 説明変数に次の指定が可能

      • 日経平均株価以外の株価(--codeオプション)

      • テクニカル指標(--elmオプション)

      • 複数日の株価(--dayオプション)

      • Neural Network Console向けRNNフォーマット出力(--rnnオプション)

    • 学習/評価データの出力が可能

    • --csvオプションの廃止

      • --eオプションを使用した方法に変更(詳細は「Pythonプログラムの実行方法」を参照ください)


いいなと思ったら応援しよう!