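Super-Easy Stock Price Prediction in Python (Using GluonTS): Time-Series Forecasting
Forecast a stock price up to 25 days ahead with super-easy time-series forecasting in Python, using GluonTS (made by Amazon).
For a similar tool made by Facebook, see my earlier post.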
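1. Install the tools
```shell
$ pip install mxnet~=1.7 gluonts pandas-datareader scikit-learn matplotlib
```
The script below uses the MXNet-based GluonTS modules (`gluonts.model.deepar`, `gluonts.mx.trainer`). Newer GluonTS releases have reorganized some of these modules, so the imports may need adjusting if you install a recent version.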
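`mxnet~=1.7` uses pip's compatible-release operator from PEP 440: it means `>=1.7` together with `==1.*`, so any 1.x release from 1.7 on is accepted, but MXNet 2.x is not. The sketch below expresses that rule for plain numeric version strings only. `is_compatible_release` is a hypothetical helper for illustration; it ignores pre-release and local version tags, which the real `packaging` library handles.

```python
from __future__ import annotations


def parse_numeric(version: str) -> tuple[int, ...]:
    """Turn '1.7.0' into (1, 7, 0). Only plain numeric versions are supported."""
    return tuple(int(part) for part in version.split("."))


def is_compatible_release(installed: str, spec: str) -> bool:
    """Check `installed` against `~=spec` (PEP 440 compatible release).

    '~=1.7' means '>=1.7, ==1.*': the spec minus its last component is the
    prefix that must match, and the installed version must be >= the spec.
    """
    spec_parts = parse_numeric(spec)
    if len(spec_parts) < 2:
        raise ValueError("~= needs at least two version components")
    inst_parts = parse_numeric(installed)
    prefix = spec_parts[:-1]
    # Pad with zeros so that (1, 7) compares like (1, 7, 0).
    width = max(len(inst_parts), len(spec_parts))
    padded_inst = inst_parts + (0,) * (width - len(inst_parts))
    padded_spec = spec_parts + (0,) * (width - len(spec_parts))
    return inst_parts[: len(prefix)] == prefix and padded_inst >= padded_spec


for v in ["1.6.0", "1.7.0", "1.9.1", "2.0.0"]:
    print(v, is_compatible_release(v, "1.7"))
# 1.6.0 False
# 1.7.0 True
# 1.9.1 True
# 2.0.0 False
```

With three components, the prefix is tighter: `~=1.4.2` allows 1.4.x from 1.4.2 on, but not 1.5.0.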
2. Create the file
pred.py
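```python
# pred.py: train DeepAR on AAPL closing prices and forecast the next 25 steps
import matplotlib.pyplot as plt
import pandas_datareader as pdr
from gluonts.dataset.common import ListDataset
from gluonts.dataset.util import to_pandas
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer
from sklearn.model_selection import train_test_split

# Daily closing prices (one value per trading day).
# Yahoo has changed its API over the years, so get_data_yahoo may fail on newer setups.
close = pdr.get_data_yahoo("AAPL", "2019-11-01", "2020-11-01")["Close"]

# Keep time order: first 80% for training, last 20% for testing.
training, test = train_test_split(close, test_size=0.2, shuffle=False)

# "B" (business days) fits trading-day data better than calendar days ("d").
FREQ = "B"
PREDICTION_LENGTH = 25  # forecast horizon, in FREQ steps

training_data = ListDataset(
    [{"start": training.index[0], "target": training.to_numpy()}],
    freq=FREQ,
)

estimator = DeepAREstimator(
    freq=FREQ,
    prediction_length=PREDICTION_LENGTH,
    trainer=Trainer(epochs=10),
)
predictor = estimator.train(training_data=training_data)

# The test series is the context; the forecast covers the 25 steps after its last value.
test_data = ListDataset(
    [{"start": test.index[0], "target": test.to_numpy()}],
    freq=FREQ,
)

for test_entry, forecast in zip(test_data, predictor.predict(test_data)):
    to_pandas(test_entry)[-150:].plot(figsize=(12, 5), linewidth=2)  # observed prices
    forecast.plot(color="g")  # median line plus shaded intervals

plt.grid(which="both")
plt.legend(["observations", "median prediction", "90% prediction interval", "50% prediction interval"])
plt.savefig("pred.png")
```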
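With `shuffle=False`, `train_test_split` keeps the original order, so the test set is simply the last 20% of the trading days and the model never trains on "future" prices. scikit-learn rounds a fractional test size up (`ceil(test_size * n)`) and gives the remaining points to training. The sketch below reproduces that behavior on a plain list so you can see which points end up where. `chronological_split` is a hypothetical helper, and the prices are made up.

```python
import math
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chronological_split(values: Sequence[T], test_size: float) -> Tuple[List[T], List[T]]:
    """Split an ordered series into (train, test) without shuffling.

    Mirrors train_test_split(..., test_size=test_size, shuffle=False):
    the test part is the last ceil(test_size * n) items.
    """
    n = len(values)
    n_test = math.ceil(test_size * n)
    n_train = n - n_test
    return list(values[:n_train]), list(values[n_train:])


prices = [100, 101, 99, 102, 104, 103, 105, 107, 106, 108]  # made-up closing prices
train, test = chronological_split(prices, test_size=0.2)
print(train)  # [100, 101, 99, 102, 104, 103, 105, 107]
print(test)   # [106, 108]
```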
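GluonTS does not read the dates in the pandas index: a `ListDataset` entry keeps only `start` and the values, and timestamps are rebuilt as `start` plus one `freq` step per value. The closing-price series has one value per trading day, so with `freq="d"` each step is labelled as one calendar day, and a 25-step forecast covers 25 calendar days. With `freq="B"` the same 25 steps cover 25 weekdays, about five weeks. The sketch below labels N steps both ways using only `datetime`. `label_steps` is a hypothetical helper; like pandas' `"B"`, it skips only Saturdays and Sundays, not exchange holidays.

```python
from datetime import date, timedelta
from typing import List


def label_steps(start: date, periods: int, business_days: bool) -> List[date]:
    """Return one date label per step, similar to pd.date_range(start, periods=..., freq=...).

    business_days=False mimics freq="D" (every calendar day);
    business_days=True mimics freq="B" (Monday to Friday; holidays are not skipped).
    """
    labels: List[date] = []
    current = start
    while len(labels) < periods:
        if not business_days or current.weekday() < 5:  # 0=Mon ... 4=Fri
            labels.append(current)
        current += timedelta(days=1)
    return labels


start = date(2020, 11, 2)  # a Monday
daily = label_steps(start, 25, business_days=False)
business = label_steps(start, 25, business_days=True)
print(daily[-1])     # 2020-11-26
print(business[-1])  # 2020-12-04
```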
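DeepAR is probabilistic: each forecast is built from many sampled future paths, and `forecast.plot` draws the median of those samples plus shaded bands. A 90% band spans the 5th to 95th percentiles across paths at each step, and a 50% band spans the 25th to 75th. The sketch below computes those per-step percentiles from a small, made-up set of sample paths in plain Python, using linear interpolation between sorted values (NumPy's default `quantile` rule). `quantile` and `interval` are hypothetical helpers, not GluonTS functions, and GluonTS's own interpolation rule may differ in detail depending on the version.

```python
from typing import List, Sequence, Tuple


def quantile(values: Sequence[float], q: float) -> float:
    """Linear-interpolation quantile (NumPy's default method) of one set of samples."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])


def interval(samples: List[List[float]], level: float) -> List[Tuple[float, float]]:
    """Central prediction interval at each forecast step.

    samples[i][t] is the value of sample path i at step t;
    level=0.9 uses the 5th and 95th percentiles across paths.
    """
    tail = (1.0 - level) / 2.0
    return [
        (quantile([path[t] for path in samples], tail),
         quantile([path[t] for path in samples], 1.0 - tail))
        for t in range(len(samples[0]))
    ]


# Five made-up sample paths over two forecast steps.
paths = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
medians = [quantile([p[t] for p in paths], 0.5) for t in range(2)]
print(medians)               # [3.0, 30.0]
print(interval(paths, 0.5))  # [(2.0, 4.0), (20.0, 40.0)]
```

The wider the spread of the sampled paths at a step, the wider the shaded band in the chart at that date.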
3. Run
```shell
$ python pred.py
```
That's it. Super easy! The forecast chart is saved as pred.png.