見出し画像

Google Colabで時系列基盤モデルを試す①:Google timesfm


はじめに

Transformerアーキテクチャにテキストデータを大量に読み込ませたらある程度あらゆる場面で使えるモデルができたというのがGPTやBERTなどの言語のFoundation Model(基盤モデル)です。
それと同じ発想で、あらゆる時系列データを読み込ませたら、あらゆる場面で使える時系列モデルが作れるのではないかという発想で作ったのが時系列の基盤モデルになります。

HuggingFaceにある商用可能なライセンスの時系列基盤モデルを4つ試し、比較していきたいと思います。利用するデータはETTh1という電力変圧器温度に関する多変量時系列データセットです。事前学習にこのデータが含まれる可能性があるため、モデルの絶対的な評価に繋がらないことに注意してください。

  1. google/timesfm-1.0-200m (今回)

    • ダウンロード数:4.59k

    • モデルサイズ:200m

    • ライセンス:Apache-2.0

  2. AutonLab/MOMENT-1-large

    • ダウンロード数:5.79k

    • モデルサイズ:385m

    • ライセンス:MIT

  3. ibm-granite/granite-timeseries-ttm-v1

    • ダウンロード数:10.1k

    • モデルサイズ:805k (小さい!!)

    • ライセンス:Apache-2.0

  4. amazon/chronos-t5-large

    • ダウンロード数:256k (多い!!)

    • モデルサイズ:709m

    • ライセンス:Apache-2.0

6月2日時点でダウンロード数が少ない順に実施していきます。今回はGoogleのtimesfmです。他3つのモデルがDecoder-Encoder Architectureなのに対して、timesfmはDecoder Only Architectureなのが特徴的です。
timesfmはFine Tuningのコードが公開されていないので、推論のみを実施します。

1. 推論

ライブラリの準備
timesfmモデルを動かすためのライブラリtimesfmは、2024/6/2現在でARMアーキテクチャをサポートしておらず、Apple siliconのMacでは動作しないようです。

# library install
!pip install git+https://github.com/google-research/timesfm.git
!pip install utilsforecast

データ準備
あらかじめETTh1.csvを取得しておいてください。取得はこちらからできます。

import pandas as pd

# データ読み込み
# https://github.com/zhouhaoyi/ETDataset/blob/main/ETT-small/ETTh1.csv
df = pd.read_csv("ETTh1.csv")
print(len(df))
df.head(2)

データは以下のような形式です。OT(Oil Temperature)が目的変数となります。

データ加工
モデルに与える長さを512、予測する長さを96と今回はおきます。Fine Tuningに使うため、後ろから予測するデータをとっておきます。

import torch

context_length = 512
forecast_horizon = 96

# データセット分割
df_train = df.iloc[-(context_length+forecast_horizon):-forecast_horizon]
df_test = df.iloc[-forecast_horizon:]

# 形式の変更
train_tensor = torch.tensor(df_train[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
train_tensor = train_tensor.t()
test_tensor = torch.tensor(df_test[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
test_tensor = test_tensor.t()

モデルの取得
input_patch_len, output_patch_len, num_layers, model_dimsは固定でこの値です。
context_lenは予測する際に元とする時系列長で、512を最大とした好きな値を設定できます。horizon_lenが予測する時系列長でこれも好きな値が設定できますがcontext_lenを超えないことを推奨されています。

import timesfm

tfm = timesfm.TimesFm(
    context_len=context_length,
    horizon_len=forecast_horizon,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
    backend="gpu",
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

推論
freqには与える時系列データの頻度を{0, 1, 2}で入力します。

0(default):高頻度、長ホライズン時系列。日単位までの時系列に使用することを推奨する。
1: 中頻度時系列。週次および月次データに使用することを推奨する。
2: 低頻度、短ホライズン時系列。月次を超えるデータ、例えば四半期や年次のデータに使用することをお勧めします。

https://huggingface.co/google/timesfm-1.0-200m#perform-inference

point_forecastに予測結果が格納されます。
point_forecastはポイント予測(つまり1時間単位で1点の予測)であり、experimental_quantile_forecastは確率予測のためのデータとなります。確率予測はまだ正式にサポートされておらず現状はポイント予測のみを使います。

# 予測の実行
frequency_input = [0] * train_tensor.size(0)
point_forecast, experimental_quantile_forecast = tfm.forecast(
    train_tensor,
    freq=frequency_input,
)
forecast_tensor = torch.tensor(point_forecast)
quantile_tensor = torch.tensor(experimental_quantile_forecast)

推論結果の出力
予測結果としてOT(Oil Temperature)がみたいのでchannel_idxとして6を指定します。

import matplotlib.pyplot as plt

channel_idx = 6
time_index = 0

history = train_tensor[channel_idx, :].detach().numpy()
true = test_tensor[channel_idx, :].detach().numpy()
pred = forecast_tensor[channel_idx, :].detach().numpy()

plt.figure(figsize=(12, 4))

# Plotting the first time series from history
plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')

# Plotting ground truth and prediction
num_forecasts = len(true)

offset = len(history)
plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (96 timesteps)', color='darkblue', linestyle='--', alpha=0.5)
plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (96 timesteps)', color='red', linestyle='--')

plt.title(f"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})", fontsize=18)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.legend(fontsize=14)
plt.show()

グラフは以下のようになりました。0-shotでも予測ができています。

ETTh1のOT(oil temperature)の予測結果

2. Fine Tuning

timesfmはFine Tuningのコードが公開されていませんでした。

3. 推論(Fine Tuning後)

timesfmはFine Tuningのコードが公開されていませんでした。

4. 結果

GoogleのTimesfmはFineTuningのコードがなかったので前後の比較はできませんでしたが、FTなしでもある程度正確に予測ができている気がします。

頻度(freq)を指定すること、512を最大として、入力と予測長を自由に指定できるところが特徴的だったと思います。
他の記事と見比べていただければと思いますが、性能としてはサイズ的に順当な性能といった感じに思えました。(全モデル事前学習データに含まれていないという前提で、となりますが。)

次回

参照


いいなと思ったら応援しよう!