RNNをスクラッチで実装してみる③:新しいデータのクラス予測

前回までのRNN(Recurrent Neural Network)スクラッチ実装では、順伝播(フォワードパス)と逆伝播(バックプロパゲーション)の基本的な仕組みを学び、RNNモデルの学習過程を確認しました。今回は、学習が進んだRNNモデルに新しいデータを入力し、どのようにクラスを予測するかを紹介します。

1. RNNで新しいデータを入力する

まず、すでに学習したモデルに対して新しいデータを入力し、その結果をクラスとして表示させる方法を説明します。前回までで学習したRNNでは、アルファベット "A"~"E" という5クラスのデータを用いてモデルを訓練してきました。

新しいデータとして、["C", "D", "E", "A"] という文字列をモデルに入力し、それに対する予測を行います。

2. 新しい入力データのone-hotエンコーディング

まず、入力データをone-hotエンコードします。one-hotエンコードは、各クラスを0と1で表現し、どのクラスであるかを示すベクトルです。

new_input = np.array([["C", "D", "E", "A"]])
one_hot_new_input = string_to_one_hot(new_input)

print("New input one-hot encoded shape:", one_hot_new_input.shape)

これにより、["C", "D", "E", "A"] というデータがone-hotエンコードされます。例えば、"C"は [0, 0, 1, 0, 0] というベクトルで表されます。

3. RNNモデルに新しいデータを通す

次に、RNNの隠れ層の初期状態をゼロベクトルにリセットし、新しいデータを順にRNNモデルに通して予測を行います。

new_inputs = one_hot_new_input[0]  # (3, 5, 1) 3タイムステップがある

# 隠れ層の初期状態をリセット
a_t_minus_1 = np.zeros((hidden_size, 1))

# 予測を計算
S_new = np.zeros((new_inputs.shape[0], hidden_size, 1))
A_new = np.zeros((new_inputs.shape[0], hidden_size, 1))
O_new = np.zeros((new_inputs.shape[0], size, 1))
Yhat_new = np.zeros((new_inputs.shape[0], size, 1))

for time_step in range(new_inputs.shape[0]):
    inputs = new_inputs[time_step, :, :]
    if time_step == 0:
        S_new[time_step] = W @ a_t_minus_1 + U @ inputs + B
    else:
        S_new[time_step] = W @ A_new[time_step-1] + U @ inputs + B
    A_new[time_step] = np.tanh(S_new[time_step])  # 隠れ層のアクティベーション
    O_new[time_step] = V @ A_new[time_step] + C  # 出力
    Yhat_new[time_step] = softmax(O_new[time_step])  # softmax 適用

ここでは、各タイムステップ("C", "D", "E", "A")に対して、出力層で計算された結果をsoftmax関数を通じてクラスの確率として表します。

4. クラスに変換して結果を表示

最後に、出力された確率を最大値を持つインデックスに変換して、それがどのクラスに該当するかを決定します。

# クラスを表示
predicted_classes = np.zeros((new_inputs.shape[0], 1))
for time_step in range(new_inputs.shape[0]):
    predicted_classes[time_step] = np.argmax(Yhat_new[time_step])

# 予測クラスの表示
print("Predicted classes for new input [C, D, E, A]:")
print(predicted_classes)

上記のコードにより、モデルは新しい入力に対して次のクラス予測を行いました。

Predicted classes for new input [C, D, E, A]:
[[3.]
 [4.]
 [0.]
 [1.]]

結果の解釈

  • C に対して予測されたクラスは 3 です。これは、アルファベット D に対応します。

  • D に対して予測されたクラスは 4 で、これは E に対応します。

  • E に対して予測されたクラスは 0 で、これは A に対応します。

  • A に対して予測されたクラスは 1 で、これは B に対応します。

これは、RNNがシーケンスのパターンを学習し、次に続くクラスを予測できるようになったことを示しています。今回の例では、["C", "D", "E", "A"] に対して次に来る文字(クラス)がそれぞれ D, E, A, B であると予測しています。

今後の課題

今回のRNNモデルはシンプルなデータでの訓練結果ですが、より多様なシーケンスや複雑なデータに対しても性能を発揮できるよう、さらなる調整が必要です。例えば以下の点に注力して改善していきます。

  1. データセットの拡充:より多様で大きなデータセットを用意する。

  2. モデルのチューニング:隠れ層のサイズや学習率など、ハイパーパラメータの調整。

  3. 長期依存の学習:LSTMやGRUなど、長期依存を扱えるモデルの導入。

今後は、RNNの拡張版としてLSTM(Long Short-Term Memory)をスクラッチで実装し、長期依存関係の学習に挑戦してみたいと思います。

いいなと思ったら応援しよう!