見出し画像

明日の日経平均株価は陽線 or 陰線? AI向け学習データおよび評価データを作成するPythonプログラム


AI向け学習データおよび評価データを作成するPythonプログラム

翌営業日の日経平均株価が陽線か、あるいは、陰線かを予測するAI向けに学習データおよび評価データを作成するPythonプログラムを用意しました。

ちなみに、学習データおよび評価データは、Neural Network Console向けとなります。

Neural Network Consoleは、SONYが提供しているグラフィカルなUIを備えたAI開発ツールです。

現状、Neural Network Consoleのクラウド版のサービスについては、2024年12月25日で終了するとのアナウンスがされています。

私は、Windows版(無料)を使用しています。

私が作成したPythonプログラムのソースコードは下記となります。

# -*- coding: utf-8 -*-

import pandas as pd
import pandas_datareader.data as web
import datetime
import yfinance as yf
import argparse
import talib as ta


# デバッグ用設定
pd.set_option('display.max_rows', 100)


# Remove exception date from stock data
def FuncRemoveExceptionDate(Data):
    TmpData = Data.copy()

    # Exception date list
    ExceptionDateList = [
        '1997-07-20', # 海の日
        '1999-07-20',
        '1999-09-15',
        '1999-09-23',
        '1999-10-11',
        '1999-11-03',
        '1999-11-23',
        '1999-12-23',
        '1999-12-31',
        '2000-01-03',
        '2000-01-10',
        '2000-02-11',
        '2000-03-20',
        '2000-05-03',
        '2000-05-04',
        '2000-05-05',
        '2000-07-20',
        '2000-09-15',
        '2000-10-09',
        '2000-11-03',
        '2000-11-23',
        '2001-01-01',
        '2001-01-02',
        '2001-01-03',
        '2001-01-08',
        '2001-02-12',
        '2001-03-20',
        '2001-04-30',
        '2001-05-03',
        '2001-05-04',
        '2001-07-20',
        '2001-09-24',
        '2001-10-08',
        '2001-11-23',
        '2001-12-24',
        '2001-12-31',
        '2002-01-01',
        '2002-01-02',
        '2002-01-03',
        '2002-01-14',
        '2002-02-11',
        '2002-03-21',
        '2002-04-29',
        '2002-05-03',
        '2002-05-06',
        '2002-09-16',
        '2002-09-23',
        '2002-10-14',
        '2002-11-04',
        '2002-12-23',
        '2002-12-31',
        '2003-01-01',
        '2003-01-02',
        '2003-01-03',
        '2003-01-13',
        '2003-02-11',
        '2003-03-21',
        '2003-04-29',
        '2003-05-05',
        '2003-07-21',
        '2003-09-15',
        '2003-09-23',
        '2003-10-13',
        '2003-11-03',
        '2003-11-24',
        '2003-12-23',
        '2003-12-31',
        '2004-01-01',
        '2004-01-02',
        '2004-01-12',
        '2004-02-11',
        '2004-04-29',
        '2004-05-03',
        '2004-05-04',
        '2004-05-05',
        '2004-07-19',
        '2004-09-20',
        '2004-09-23',
        '2004-10-11',
        '2004-11-03',
        '2004-11-23',
        '2004-12-23',
        '2004-12-31',
        '2005-01-03',
        '2005-01-10',
        '2005-02-11',
        '2005-03-21',
        '2005-04-29',
        '2005-05-03',
        '2005-05-04',
        '2005-05-05',
        '2005-07-18',
        '2005-09-19',
        '2005-09-23',
        '2005-10-10',
        '2005-11-03',
        '2005-11-23',
        '2005-12-23',
        '2006-01-02',
        '2006-01-03',
        '2006-01-09',
        '2006-03-21',
        '2006-05-03',
        '2006-05-04',
        '2006-05-05',
        '2006-07-17',
        '2006-09-18',
        '2006-10-09',
        '2006-11-03',
        '2006-11-23',
        '2009-09-21',
        '2017-07-17',
        '2017-08-11',
        '2017-09-18',
        '2017-10-09',
        '2017-11-03',
        '2017-11-23',
        '2018-01-01',
        '2018-01-02',
        '2018-01-03',
        '2018-01-08',
        '2018-02-12',
        '2018-03-21',
        '2018-04-30',
        '2018-05-03',
        '2018-05-04',
        '2018-07-16', # 海の日
        '2018-09-17',
        '2018-09-24',
        '2018-10-08',
        '2018-11-23',
        '2018-12-24',
        '2018-12-31',
        '2020-10-01' # 東証システム障害により全銘柄の売買を終日停止
    ]

    for TmpDay in ExceptionDateList:
        TmpData = TmpData[TmpData.index != TmpDay]

    return TmpData


# Stock data download from Stooq
def FuncDLStockDataStooq(Symbol, StartDate, EndDate):
    # データ(行単位)削除フラグ
    RemoveDateFlag = False

    # シンボル変換
    if 'N225' == Symbol:
        TmpSymbol = '^NKX'
        RemoveDateFlag = True
    elif 'TOPIX' == Symbol: TmpSymbol = '^TPX'
    elif 'DOW' == Symbol: TmpSymbol = '^DJI'
    elif 'SP500' == Symbol: TmpSymbol = '^SPX'
    elif 'NAS' == Symbol: TmpSymbol = '^NDQ'
    #elif 'FTSE' == Symbol: TmpSymbol = '^UK100' # NG
    #elif 'USD' == Symbol: TmpSymbol = 'USDJPY' # NG
    #elif 'EUR' == Symbol: TmpSymbol = 'EURJPY' # NG
    else:
        TmpSymbol = Symbol[0] + '.JP'
        RemoveDateFlag = True

    # Stooqから株価データをダウンロード
    Data = web.DataReader(TmpSymbol, 'stooq', StartDate, EndDate)

    # Datetimeインデックスを日付けのみに変更
    Data.index = Data.index.date

    # インデックス(日付け)を昇順に並べ替え
    Data = Data.sort_index()
    
    # Volumeの列を削除
    if 'Volume' in Data.columns.values: Data = Data.drop('Volume', axis = 1)

    # 例外の日付けの行を削除
    if RemoveDateFlag: Data = FuncRemoveExceptionDate(Data)

    # カラム(列)名を変更
    if not ('N225' == Symbol or type(Symbol) == list):
        Data = Data.rename(
            columns = {
                'Open': Symbol,
                'High': Symbol,
                'Low': Symbol,
                'Close': Symbol
            })

    return Data


# Stock data download from Yahoo
def FuncDLStockDataYahoo(Symbol, StartDate, EndDate):
    # データ(行単位)削除フラグ
    RemoveDateFlag = False

    # シンボル変換
    if 'N225' == Symbol:
        TmpSymbol = '^N225'
        RemoveDateFlag = True
    #elif 'TOPIX' == Symbol: TmpSymbol = '998405' # NG
    elif 'DOW' == Symbol: TmpSymbol = '^DJI'
    elif 'SP500' == Symbol: TmpSymbol = '^GSPC'
    elif 'NAS' == Symbol: TmpSymbol = '^IXIC'
    elif 'FTSE' == Symbol: TmpSymbol = '^FTSE'
    elif 'HSI' == Symbol: TmpSymbol = '^HSI' # 香港ハンセン指数
    elif 'SHA' == Symbol: TmpSymbol = '000001.SS' # 上海総合指数
    elif 'VIX' == Symbol: TmpSymbol = '^VIX'
    elif 'USD' == Symbol: TmpSymbol = 'USDJPY=X'
    elif 'EUR' == Symbol: TmpSymbol = 'EURJPY=X'
    elif 'CNY' == Symbol: TmpSymbol = 'CNYJPY=X' # 元(中国)
    elif 'IRX' == Symbol: TmpSymbol = '^IRX' # 米13週国債
    elif 'TNX' == Symbol: TmpSymbol = '^TNX' # 米10年国債
    else:
        TmpSymbol = Symbol[0] + '.T'
        RemoveDateFlag = True

    # Yahooは最終日が1日前にズレるため、補正が必要
    EndDate4Yahoo = (pd.to_datetime(EndDate) + datetime.timedelta(days = 1)).strftime('%Y-%m-%d')

    # Yahooから株価データをダウンロード
    Data = yf.download(TmpSymbol, StartDate, EndDate4Yahoo)

    # Datetimeインデックスを日付けのみに変更
    Data.index = Data.index.date

   # 株価データのマルチカラムをシングルカラムに変更
    if 1 < Data.columns.nlevels:
        TmpColumns = []

        for ColName in Data.columns.values:
            TmpColumns.append(ColName[0])

        Data.columns = TmpColumns

    # Adj Closeの列を削除
    if 'Adj Close' in Data.columns.values: Data = Data.drop('Adj Close', axis = 1)

    # Volumeの列を削除
    if 'Volume' in Data.columns.values: Data = Data.drop('Volume', axis = 1)

    # 例外の日付けの行を削除
    if RemoveDateFlag: Data = FuncRemoveExceptionDate(Data)

    # カラム(列)名を変更
    if not ('N225' == Symbol or type(Symbol) == list):
        Data = Data.rename(
            columns = {
                'Open': Symbol,
                'High': Symbol,
                'Low': Symbol,
                'Close': Symbol
            })

    return Data


# Stock data download
def FuncDLStockData(Symbol, StartDate, EndDate):
    # Stooqから株価データをダウンロード
    DataStooq = FuncDLStockDataStooq(Symbol, StartDate, EndDate)

    # Yahooから株価データをダウンロード
    DataYahoo = FuncDLStockDataYahoo(Symbol, StartDate, EndDate)

    # Yahooの株価データを基準に、Stooqの株価データを合成
    Data = pd.concat([DataYahoo, DataStooq])

    # 重複したインデックス(日付け)に対して最初のデータを残す
    Data = Data[~Data.index.duplicated(keep = 'first')]

    # インデックス(日付け)を昇順に並べ替え
    Data = Data.sort_index()

    return Data


# SMA
def FuncSMA(ColName, Data):
    RefCol = 'Close'

    WinShort: int = 5
    WinMiddle: int = 25
    WinLong: int = 75

    SMA1 = Data[RefCol].rolling(window = WinShort).mean().to_frame(ColName)
    SMA2 = Data[RefCol].rolling(window = WinMiddle).mean().to_frame(ColName)
    SMA3 = Data[RefCol].rolling(window = WinLong).mean().to_frame(ColName)

    return pd.concat([SMA1, SMA2, SMA3], axis = 1)


# Bollinger Band
def FuncBB(ColName, Data):
    RefCol = 'Close'

    WinBB: int = 20

    SMABB = Data[RefCol].rolling(window = WinBB).mean()
    StdDevBB = Data[RefCol].rolling(window = WinBB).std(ddof = 0)

    BBM2Sig = (SMABB - 2 * StdDevBB).to_frame(ColName)
    BBM1Sig = (SMABB - StdDevBB).to_frame(ColName)
    BBSMA = SMABB.to_frame(ColName)
    BBP1Sig = (SMABB + StdDevBB).to_frame(ColName)
    BBP2Sig = (SMABB + 2 * StdDevBB).to_frame(ColName)

    return pd.concat([BBM2Sig, BBM1Sig, BBSMA, BBP1Sig, BBP2Sig], axis = 1)


# MACD
def FuncMACD(ColName, Data):
    RefCol = 'Close'

    SpanShort: int = 12
    SpanLong: int = 26
    WinSignal: int = 9

    EMAShort = Data[RefCol].ewm(span = SpanShort).mean()
    EMALong = Data[RefCol].ewm(span = SpanLong).mean()

    # MACD
    MACD = (EMAShort - EMALong).to_frame(ColName)

    # MACD Signal
    MACDSig = MACD[ColName].rolling(window = WinSignal).mean().to_frame(ColName)

    return pd.concat([MACD, MACDSig], axis = 1)


# 一目均衡表
def FuncICHIMOKU(ColName, Data, NoLagFlag = False):
    HighCol = 'High'
    LowCol = 'Low'
    CloseCol = 'Close'

    WinBase: int = 26
    WinConv: int = 9
    Span1Shift: int = 25 # 26?
    WinSpan2: int = 52
    Span2Shift: int = 25 # 26?
    DelayShift: int = -25 # -26?

    # 基準線
    MaxBaseLine = Data[HighCol].rolling(WinBase).max()
    MinBaseLine = Data[LowCol].rolling(WinBase).min()

    BaseLineData = ((MaxBaseLine + MinBaseLine) / 2).to_frame(ColName)

    # 転換線
    MaxConvLine = Data[HighCol].rolling(WinConv).max()
    MinConvLine = Data[LowCol].rolling(WinConv).min()

    ConvLineData = ((MaxConvLine + MinConvLine) / 2).to_frame(ColName)

    # 先行スパン1
    Span1Data = ((BaseLineData[ColName] + ConvLineData[ColName]) / 2).shift(Span1Shift).to_frame(ColName)

    # 先行スパン2
    MaxSpan2Line = Data[HighCol].rolling(WinSpan2).max()
    MinSpan2Line = Data[LowCol].rolling(WinSpan2).min()

    Span2Data = ((MaxSpan2Line + MinSpan2Line) / 2).shift(Span2Shift).to_frame(ColName)

    # 遅行スパン
    LagginSpanData = Data[CloseCol].shift(DelayShift).to_frame(ColName)

    # 遅行スパンのNaNに最後の終値をコピー
    LastCloseRow: int =  len(LagginSpanData) + DelayShift - 1
    for i in range(len(LagginSpanData) + DelayShift, len(LagginSpanData)):
        LagginSpanData.iat[i, 0] = LagginSpanData.iat[LastCloseRow, 0]

    if NoLagFlag:
        # 遅行スパンなし
        return pd.concat([BaseLineData, ConvLineData, Span1Data, Span2Data], axis = 1)
    else:
        # 遅行スパンあり
        return pd.concat([BaseLineData, ConvLineData, Span1Data, Span2Data, LagginSpanData], axis = 1)


# 一目均衡表 遅行スパンなし
def FuncICHIMOKUNoLag(ColName, Data):
    return FuncICHIMOKU(ColName, Data, True)


# RSI
def FuncRSI(ColName, Data):
    RefCol = 'Close'

    SpanRSI: int = 14
    WinSignal: int = 9

    # RSI
    RSI = ta.RSI(Data[RefCol], timeperiod = SpanRSI).to_frame(ColName)

    # RSI Signal
    RSISig = RSI[ColName].rolling(window = WinSignal).mean().to_frame(ColName)

    return pd.concat([RSI, RSISig], axis = 1)


# Stochastics
def FuncSTOCH(ColName, Data):
    HighCol = 'High'
    LowCol = 'Low'
    CloseCol = 'Close'

    WinK: int = 9
    WinD: int = 3
    WinSlowD: int = 3

    # %K
    MaxData = Data[HighCol].rolling(window = WinK).max()
    MinData = Data[LowCol].rolling(window = WinK).min()
    KData = 100 * (Data[CloseCol] - MinData) / (MaxData - MinData)
    KData = KData.to_frame(ColName)

    # %D
    DData = KData[ColName].rolling(window = WinD).mean().to_frame(ColName)

    # Slow%D
    SlowDData = DData[ColName].rolling(window = WinSlowD).mean().to_frame(ColName)

    return pd.concat([KData, DData, SlowDData], axis = 1)


# ラベルを作成(二値分類用)
# 翌営業日の日経平均株価のローソク足が陽線か陰線か
def FuncLabel2ClassCandleStick(Data):
    # Openの列を抽出
    TmpOpen = Data['Open']

    # 最後のOpenの列を抽出
    if 0 < len(TmpOpen.dtypes): TmpOpen = TmpOpen.iloc[:, -1]

    # Closeの列を抽出
    TmpClose = Data['Close']

    # 最後のCloseの列を抽出
    if 0 < len(TmpClose.dtypes): TmpClose = TmpClose.iloc[:, -1]

    # 終値と始値の差分を抽出
    LabelData = TmpClose - TmpOpen

    # SeriesをDataFrameに変換
    LabelData = LabelData.to_frame('Label')

    # LabelDataを上方向に1行シフト
    LabelData = LabelData.shift(-1)

    # 終値と始値の差分が0より大きいならラベル1、0以下はラベル0
    LabelData = (0 < LabelData) * 1

    return LabelData


# 説明変数およびRNN向けにデータを横に並べる関数
def DataCopyShift(CopyNum, Data):
    # 出力用のDataFrame変数
    OutData = Data.copy()

    for i in range(CopyNum - 1):
        # Dataをコピー
        TmpData = Data.copy()

        # TmpDataを上方向にi + 1行シフト(i = 0, 1, ...)
        TmpData = TmpData.shift(-(i + 1))

        # OutDataの右側に結合
        OutData = pd.concat([OutData, TmpData], axis = 1)

    # NaNを含む行を削除
    OutData = OutData.dropna()

    return OutData


# 標準化(Standardization)
def FuncStd(Data):
    TmpMeanData = Data.mean(axis = None)
    TmpStdData = Data.values.std(ddof = 0)
    TmpData = (Data - TmpMeanData) / TmpStdData

    return TmpData, TmpMeanData, TmpStdData


# リスト実行用関数
def FuncListExec(EnList, FuncList, StdMode, Args):
    # リスト実行用インデックス
    I_Name: int = 0
    I_Func: int = 1

    # 出力用の空のデータフレーム
    OutData = pd.DataFrame()

    for EID in EnList:
        for Func in FuncList:
            if EID == Func[I_Name]:
                # 関数を実行
                if type(Args) is list:
                    FuncData = Func[I_Func](EID, *Args)
                else:
                    FuncData = Func[I_Func](EID, Args)

                # 個別に標準化を行う
                if not StdMode:
                    # NaNを含む行を削除
                    FuncData = FuncData.dropna()

                    # データの標準化
                    TmpData = FuncStd(FuncData)

                    # 標準化されたデータを代入
                    FuncData = TmpData[0]

                    # 標準化パラメータの表示
                    print('Standardization(' + Func[I_Name] + '): Mean = ', TmpData[1], ', Std = ', TmpData[2])

                # OutDataとFuncDataのインデックスに基づいてデータを横に結合
                OutData = OutData.join(FuncData, how = 'outer')

    return OutData


def main():
    # コマンドライン引数の処理
    parser = argparse.ArgumentParser(description = 'Stock data to Training data tool for AI model')
    parser.add_argument('--csv', required = False, help = 'Stock data file name(csv format)')
    parser.add_argument('--code', required = False, nargs = 1, help = 'Stock code number')
    parser.add_argument('--day', required = False, type = int, default = 1, help = 'Prameters for explanatory variables(Default: 1)')
    parser.add_argument('--dl', required = False, help = 'csv file name(Download stock data to csv file)')
    parser.add_argument('--elm', required = False, nargs = '+', help = 'Element to add to training data(Option: SMA BB MACD ICHIMOKU RSI STOCH TOPIX DOW SP500 NAS FTSE HSI SHA VIX USD EUR CNY IRX TNX)')
    parser.add_argument('--rnn', required = False, type = int, default = 1, help = 'Repeat parameter for RNN(Default: 1)')
    parser.add_argument('--s', required = False, default = '1990-01-01', help = 'Start date(Default: 1990-01-01)')
    parser.add_argument('--e', required = False, default = '2023-12-31', help = 'End date(Default: 2023-12-31)')
    parser.add_argument('--vnum', required = False, type = int, default = 250, help = 'Number of rows for validation data(Default: 250)')
    parser.add_argument('--t', required = False, default = 'training.csv', help = 'Training data file name(Default: training.csv)')
    parser.add_argument('--v', required = False, default = 'validation.csv', help = 'Validation data file name(Default: validation.csv)')
    parser.add_argument('--nolabel', required = False, action = 'store_true', help = 'No label output mode(Default output file: nolabel.csv)')
    parser.add_argument('nolabelfile', nargs = '?', default = 'nolabel.csv')
    parser.add_argument('--nostd', required = False, action = 'store_true', help = 'No standardization(Default: Excecute standardization)')
    parser.add_argument('--debug', required = False, action = 'store_true', help = 'Debug mode')

    args = parser.parse_args()


    # 指定期間の表示
    DayStart = pd.to_datetime(args.s).strftime('%Y-%m-%d')
    DayEnd = pd.to_datetime(args.e).strftime('%Y-%m-%d')


    # --dayオプションに対する補正処理
    DayStartCorrect = pd.to_datetime(DayStart)
    if 1 < args.day:
        # --dayオプションで指定した日数の2倍 + 10日を引く
        DayStartCorrect = DayStartCorrect - datetime.timedelta(days = (2 * args.day + 10))

    # datetimeをstrに変換
    DayStartCorrect = DayStartCorrect.strftime('%Y-%m-%d')


    print('Start day:', DayStartCorrect)
    print('End day:', DayEnd)


    # 株価データの取得先を選択(csv file or Website)
    if args.csv:
        print('csv file:', args.csv)

        # csvファイルのデータをDataFrameに入力
        StockData = pd.read_csv(args.csv, encoding = 'utf-8')

        # DataFrameのインデックスをDateに変更し、オリジナルのDataFrameも更新
        StockData.set_index('Date', inplace = True)

        # 指定期間のみ抽出
        StockData = StockData[DayStart : DayEnd]

        # Volumeの列を削除
        if 5 == len(StockData.columns): StockData = StockData.drop('Volume', axis = 1)

        # 日付けを昇順に並べ替え(csvファイルによる株価データの使用を想定)
        StockData = StockData.sort_index(ascending = True)
    else:
        if not args.code:
            # 日経平均株価をダウンロード
            StockData = FuncDLStockData('N225', DayStartCorrect, DayEnd)
        else:
            # --codeオプションで指定された銘柄の株価をダウンロード
            StockData = FuncDLStockData(args.code, DayStartCorrect, DayEnd)


    # 学習データの作成処理
    if args.dl:
        print('Output stock data to', args.dl)

        # 株価データをcsvファイルに出力
        StockData.to_csv(args.dl)
    else:
        # 各テクニカル分析指標のイネーブル設定
        EnableList = []

        if args.elm: EnableList = args.elm

        print('Training data element:', ', '.join(EnableList))


        # 株価データの標準化
        if not args.nostd:
            TmpData = FuncStd(StockData) # データの標準化
            OutData = TmpData[0]

            # 標準化パラメータの表示
            print('Standardization(Stock): Mean = ', TmpData[1], ', Std = ', TmpData[2])
        else:
            OutData = StockData

            print('No standardization!!')


        # 関数リスト
        # 名前, 関数名
        FuncList = [
            # SMA(Short, Middle, Long)
            ['SMA', FuncSMA],

            # Bollinger Band(-2Sig, -Sig, SMA, +Sig, +2Sig)
            ['BB', FuncBB],

            # MACD(MACD, MACDSignal)
            ['MACD', FuncMACD],

            # 一目均衡表(基準線, 転換線, 先行スパン1, 先行スパン2, 遅行スパン)
            ['ICHIMOKU', FuncICHIMOKU],

            # 一目均衡表(基準線, 転換線, 先行スパン1, 先行スパン2)
            ['ICHIMOKUNOLAG', FuncICHIMOKUNoLag],

            # RSI(RSI, RSI Signal)
            ['RSI', FuncRSI],

            # Stochastics(K, D, SlowD)
            ['STOCH', FuncSTOCH]
        ]


        # ダウンロードリスト
        # 名前, 関数名
        DlList = [
            # TOPIX
            ['TOPIX', FuncDLStockDataStooq],

            # DOW
            ['DOW', FuncDLStockData],

            # S&P500
            ['SP500', FuncDLStockData],

            # NASDAQ
            ['NAS', FuncDLStockData],

            # VIX
            ['VIX', FuncDLStockDataYahoo],

            # FTSE100
            ['FTSE', FuncDLStockDataYahoo],

            # 香港ハンセン指数
            ['HSI', FuncDLStockDataYahoo],

            # 上海総合指数
            ['SHA', FuncDLStockDataYahoo],

            # USDJPY
            ['USD', FuncDLStockDataYahoo],

            # EURJPY
            ['EUR', FuncDLStockDataYahoo],

            # 元(中国)
            ['CNY', FuncDLStockDataYahoo],

            # 米13週国債
            ['IRX', FuncDLStockDataYahoo],

            # 米10年国債
            ['TNX', FuncDLStockDataYahoo]
        ]


        # 指定された関数を実行
        TmpData = FuncListExec(EnableList, FuncList, args.nostd, StockData)

        # OutDataのインデックスを基準にデータを横に結合
        OutData = OutData.join(TmpData)

        # 各種データのダウンロードを実行
        TmpData = FuncListExec(EnableList, DlList, args.nostd, [DayStartCorrect, DayEnd])

        # OutDataのインデックスを基準にデータを横に結合
        # 欠損値は前営業日の値をコピー
        # ★ 1日目の欠損値は前日の値をコピーできないため削除される
        # ★ 削除を避けるには--sオプションで開始日を前に調整すること
        OutData = OutData.join(TmpData).ffill()


        # --dayオプションの処理
        print('Day params:', args.day)

        # str型をdatetime.date型に変換
        TmpDayStart = datetime.datetime.strptime(DayStart, '%Y-%m-%d').date()

        # --sオプションで指定した開始日の行番号を取得
        RowStart = OutData.index.get_loc(OutData[TmpDayStart <= OutData.index].head(1).index[0])

        # --dayオプションの指定を考慮した範囲を抽出
        OutData = OutData[RowStart - args.day + 1:]

        # --dayオプションの指定分だけ各データを横に並べる処理
        if 1< args.day:
            OutData = DataCopyShift(args.day, OutData)


        # RNN向けに指定分だけ各データを横に並べる処理
        if 1 < args.rnn:
            print('RNN params:', args.rnn)
            OutData = DataCopyShift(args.rnn, OutData)


        # ラベルを作成する処理
        if not args.nolabel:
            # ラベルデータを取得
            TmpData = FuncLabel2ClassCandleStick(OutData)

            # ラベルデータを最も左の列として結合
            OutData = pd.concat([OutData, TmpData], axis = 1)

            # NaNを含む行を削除
            OutData = OutData.dropna()
        else:
            print('No label output!!')


        # Neural Network Console用のヘッダを設定
        # ラベルありなし用ヘッダ数を調整
        HeaderAdj: int = 1
        if args.nolabel: HeaderAdj = 0

        # 説明変数用のヘッダを用意
        NNCHeader = 'x__' + pd.Series(range(0, len(OutData.columns) - HeaderAdj), dtype = 'str')

        # 二値分類向け目的変数(ラベル)用のヘッダを追加
        if not args.nolabel:
            NNCHeader[len(NNCHeader)] = 'y:label;D;U'

        # 学習データのヘッダを上書き
        if not args.debug: OutData.columns = [NNCHeader]

        # データの分割処理
        if not args.nolabel:
            if len(OutData) < args.vnum:
                print('\nError!!')
                print('Too few rows of data:', len(OutData))
                print('Please adjust with --vnum option or --s and --e option')
            else:
                # 学習データをトレーニングデータとバリデーションデータに分割
                print('Number of rows for validation data:', args.vnum)

                TrainingData = OutData[: -args.vnum]

                print('Training data period:', TrainingData.index[0], '-', TrainingData.index[-1])

                ValidationData = OutData[len(OutData) - args.vnum :]

                print('Validation data period:', ValidationData.index[0], '-', ValidationData.index[-1])
        else:
            # ラベル無しの場合はデータを分割しない
            print('No label data period:', OutData.index[0], '-', OutData.index[-1])


        if not args.debug:
            # 各種データをcsvファイルに出力
            if not args.nolabel:
                if args.vnum < len(OutData):
                    # トレーニングデータをcsvファイルに出力
                    print('Training data csv file:', args.t)
                    TrainingData.to_csv(args.t, index = False, float_format = '%.4f')

                    # バリデーションデータをcsvファイルに出力
                    print('Validation data csv file:', args.v)
                    ValidationData.to_csv(args.v, index = False, float_format = '%.4f')
            else:
                # ラベル無しデータをcsvファイルに出力
                print('No label data csv file:', args.nolabelfile)
                OutData.to_csv(args.nolabelfile, index = False, float_format = '%.4f')
        else:
            # Debug mode
            if not args.nolabel:
                print('\nTraining data')
                print(TrainingData)

                print('\nValidation data')
                print(ValidationData)
            else:
                print('\nNo label data')
                print(OutData)


    print('\n' + 'Done.')


if __name__ == '__main__':
    main()

私のPython環境(Python 3.12.3)でデバッグを行いましたが、まだ何らかのバグが残っている可能性はあります。
m(_ _)mオユルシクダサイ

Pythonプログラムで使用可能なオプション

Pythonプログラムで使用可能なオプションは、以下の通りです。

  • -h, --help: ヘルプメッセージが出力されます

  • --csv:  事前にダウンロードした株価データを使用して学習データおよび評価データを作成する場合に使用します(csvファイル名を指定)

  • --code: 日経平均株価以外の銘柄の株価データを使用して学習データおよび評価データを作成する場合に使用します(証券コードを指定)

  • --dl:  StooqやYahoo Financeからダウンロードした株価データを保存するcsvファイル名を指定します

  • --elm:  学習データおよび評価データの説明変数に含ませるテクニカル分析指標、等の要素を指定します(詳細は下記を参照)

  • --rnn:  Neural Network ConsoleのRNN用にn段のデータを1行ベクトルとして作成します(デフォルトは1段)

  • --s:  取得する株価データの開始年月日を指定します(デフォルトは1990-01-01)

  • --e:  取得する株価データの終了年月日を指定します(デフォルトは2023-12-31)

  • --vnum:  評価データの日数を指定します(デフォルトは250日)

  • --t:  学習データを保存するcsvファイル名を指定します(デフォルトはtraining.csv)

  • --v:  評価データを保存するcsvファイル名を指定します(デフォルトはvalidation.csv)

  • --nolabel:  学習済みAI向けに推論用のデータを保存する際に指定します(デフォルトはnolabel.csv)

  • --nostd:  学習データおよび評価データに対して標準化を行わない場合に指定します(デフォルトは標準化を行います)

  • --debug: デバッグモードを有効化します(デフォルトは無効)

また、--elmオプションで指定可能なテクニカル分析指標、等の要素は、以下の通りです。

  • --elmオプションで指定可能なテクニカル分析指標、等の種類と引数

    • テクニカル分析指標

      • 移動平均(5, 25, 75日): SMA

      • ボリンジャーバンド(±2σ, ±1σ, SMA): BB

      • MACD(MACD, MACD Signal): MACD

      • 一目均衡表(基準線、転換線、先行スパン1, 2): ICHIMOKU

      • RSI(RSI, RSI Signal): RSI

      • ストキャスティクス(%K %D, Slow%D): STOCH

    • 株価データ、等

      • 東証株価指数: TOPIX

      • ダウ平均株価: DOW

      • S&P500: SP500

      • ナスダック総合指数: NAS

      • FTSE100種総合株価指数: FTSE

      • 香港ハンセン株価指数: HSI

      • 上海総合指数: SHA

      • VIX指数: VIX

      • ドル円: USD

      • ユーロ円: EUR

      • 元円: CNY

      • 米13週国債金利: IRX

      • 米10年国債金利: TNX

--elmオプションでテクニカル分析指標、等を指定する場合に注意点があります。

それは、指定するパラメータの順序により、説明変数の並びも変わってしまうということです。

例えば、--elm RSI SMAと指定した場合、説明変数の並びは、以下のようになります。

  • 株価データ(Open, High, Low, Close), RSI(RSI, RSI Signal), SMA(5, 25, 75日)

Pythonプログラムを実行する際、--elmオプションで指定するパラメータの順序に制約はありませんが、説明変数の並びにはご注意ください。

ただし、Pythonプログラムの都合で、テクニカル分析指標が株価データ、等より先に並びます。

Pythonプログラムの実行方法

上記のPythonプログラムを保存したファイル名をfilename.pyとします。

例えば、下記の条件に基づいた学習データおよび評価データを作成する場合の実行方法は次の通りです。

  • 株価データとして日経平均株価を使用

  • 説明変数に追加するテクニカル分析指標、等の要素

    • 移動平均

    • MACD

    • ダウ平均株価

  • 株価データ取得期間: 2010年6月1日から2024年5月31日

> python filename.py --elm SMA MACD DOW --s 2010-6-1 --e 2024-5-31

Pythonプログラムを実行した結果、training.csvとvalidation.csvが作成されます。

また、学習済みAIの推論用データを作成する場合は、以下のオプションを指定します。

  • 学習データおよび評価データの作成時に指定したオプション(--elmの並びも同じ)

  • --s, --eオプションで適切な日時を指定

  • --nolabel


2024年11月19日 バグフィックス & 機能追加

バグフィックスと機能追加を行いました。

  • バグフィックスについて

    • --rnnオプションを使用すると、ラベルデータが適切に付与されないバグを修正しました

  • 機能追加について

    • --dayオプションを追加しました

--dayオプションは、説明変数としてn日分のデータを使用する場合に指定します。

例えば、3日分のデータを説明変数に使用する場合は、--day 3とします。

--dayオプションのデフォルト値は、1(つまり、1日)です。


2024年11月22日 バグフィックス

バグフィックスを行いました。

  • バグフィックスについて

    • --dayオプションを使用すると、説明変数が指定した日数分だけ並ばないバグを修正しました

いいなと思ったら応援しよう!