RNNをスクラッチで実装してみる③:新しいデータのクラス予測
前回までのRNN(Recurrent Neural Network)スクラッチ実装では、順伝播(フォワードパス)と逆伝播(バックプロパゲーション)の基本的な仕組みを学び、RNNモデルの学習過程を確認しました。今回は、学習が進んだRNNモデルに新しいデータを入力し、どのようにクラスを予測するかを紹介します。
1. RNNで新しいデータを入力する
まず、すでに学習したモデルに対して新しいデータを入力し、その結果をクラスとして表示させる方法を説明します。前回までで学習したRNNでは、アルファベット "A"~"E" という5クラスのデータを用いてモデルを訓練してきました。
新しいデータとして、["C", "D", "E", "A"] という文字列をモデルに入力し、それに対する予測を行います。
2. 新しい入力データのone-hotエンコーディング
まず、入力データをone-hotエンコードします。one-hotエンコードは、各クラスを0と1で表現し、どのクラスであるかを示すベクトルです。
new_input = np.array([["C", "D", "E", "A"]])
one_hot_new_input = string_to_one_hot(new_input)
print("New input one-hot encoded shape:", one_hot_new_input.shape)
これにより、["C", "D", "E", "A"] というデータがone-hotエンコードされます。例えば、"C"は [0, 0, 1, 0, 0] というベクトルで表されます。
3. RNNモデルに新しいデータを通す
次に、RNNの隠れ層の初期状態をゼロベクトルにリセットし、新しいデータを順にRNNモデルに通して予測を行います。
new_inputs = one_hot_new_input[0] # (3, 5, 1) 3タイムステップがある
# 隠れ層の初期状態をリセット
a_t_minus_1 = np.zeros((hidden_size, 1))
# 予測を計算
S_new = np.zeros((new_inputs.shape[0], hidden_size, 1))
A_new = np.zeros((new_inputs.shape[0], hidden_size, 1))
O_new = np.zeros((new_inputs.shape[0], size, 1))
Yhat_new = np.zeros((new_inputs.shape[0], size, 1))
for time_step in range(new_inputs.shape[0]):
inputs = new_inputs[time_step, :, :]
if time_step == 0:
S_new[time_step] = W @ a_t_minus_1 + U @ inputs + B
else:
S_new[time_step] = W @ A_new[time_step-1] + U @ inputs + B
A_new[time_step] = np.tanh(S_new[time_step]) # 隠れ層のアクティベーション
O_new[time_step] = V @ A_new[time_step] + C # 出力
Yhat_new[time_step] = softmax(O_new[time_step]) # softmax 適用
ここでは、各タイムステップ("C", "D", "E", "A")に対して、出力層で計算された結果をsoftmax関数を通じてクラスの確率として表します。
4. クラスに変換して結果を表示
最後に、出力された確率を最大値を持つインデックスに変換して、それがどのクラスに該当するかを決定します。
# クラスを表示
predicted_classes = np.zeros((new_inputs.shape[0], 1))
for time_step in range(new_inputs.shape[0]):
predicted_classes[time_step] = np.argmax(Yhat_new[time_step])
# 予測クラスの表示
print("Predicted classes for new input [C, D, E, A]:")
print(predicted_classes)
上記のコードにより、モデルは新しい入力に対して次のクラス予測を行いました。
Predicted classes for new input [C, D, E, A]:
[[3.]
[4.]
[0.]
[1.]]
結果の解釈
C に対して予測されたクラスは 3 です。これは、アルファベット D に対応します。
D に対して予測されたクラスは 4 で、これは E に対応します。
E に対して予測されたクラスは 0 で、これは A に対応します。
A に対して予測されたクラスは 1 で、これは B に対応します。
これは、RNNがシーケンスのパターンを学習し、次に続くクラスを予測できるようになったことを示しています。今回の例では、["C", "D", "E", "A"] に対して次に来る文字(クラス)がそれぞれ D, E, A, B であると予測しています。
今後の課題
今回のRNNモデルはシンプルなデータでの訓練結果ですが、より多様なシーケンスや複雑なデータに対しても性能を発揮できるよう、さらなる調整が必要です。例えば以下の点に注力して改善していきます。
データセットの拡充:より多様で大きなデータセットを用意する。
モデルのチューニング:隠れ層のサイズや学習率など、ハイパーパラメータの調整。
長期依存の学習:LSTMやGRUなど、長期依存を扱えるモデルの導入。
今後は、RNNの拡張版としてLSTM(Long Short-Term Memory)をスクラッチで実装し、長期依存関係の学習に挑戦してみたいと思います。