【Go・Deep Learning その④】 ゼロから作るdeep learning 3を復習 ーRNNとLSTM
これまで
今回
CNNは多次元配列がないと面倒くさそうだったため、飛ばしました。🙇♂️
なので、CNN関係+ラズパイあたりで近々記事を出そうと思います。🙇♂️
許してください。
今回はRNN・LSTMです。
時系列データとかに用いられるイメージですかね?(株とか?)
RNN ⏰
簡単にいうと隠れ層が繰り返し状態を使う感じですかね。。
なにはともわれドン(ソースコード👇)
ソースコード
const (
maxEpoch = 100
hiddenSize = 100
bpttLength = 30
)
trainSet := datasets.NewSinCurve(dz.Train(true))
seqlen := trainSet.Len()
model := models.NewSimpleRNN(hiddenSize, 1)
optimizer := optimizers.NewAdam().Setup(model)
for i := 0; i < maxEpoch; i++ {
model.ResetState()
loss, count := dz.NewVariable(core.New1D(0)), 0
for j := 0; j < trainSet.Len(); j++ {
xData, tData := trainSet.Get(j)
x := dz.AsVariable(xData)
t := dz.AsVariable(tData)
y := model.Apply(x).First()
lossVar := fn.MeanSquaredError(y, t)
loss = fn.Add(loss, lossVar)
if count += 1; (count%bpttLength) == 0 || count == seqlen {
model.ClearGrads()
loss.Backward()
loss.UnchainBackward()
optimizer.Update()
}
}
avgLoss := loss.Data().At(0, 0) / float64(count)
fmt.Println("epoch", i+1, " | loss", avgLoss)
}
簡単に説明しますと、
Sin波を学習しているのですが、
参考に記事でも貼っておきますね💔
https://qiita.com/kazukiii/items/df809d6cd5d7d1f57be3
結果 🟦: 予測線 / 🟩: Sin波
ある程度が予測できていそう。。。。
LSTM 💪
こちらは簡単に言ってしまえばRNNの進化版みたいなイメージですかね?
違いは中間層を設けるかって感じですかね? 参考<-わかりやすい
何はともわれ実装をドン
ソースコード
const (
maxEpoch = 100
hiddenSize = 100
batchSize = 30
bpttLength = 30
)
trainSet := datasets.NewSinCurve(dz.Train(true))
dataloader := loader.NewSeqDataLoader(trainSet, batchSize)
seqlen := trainSet.Len()
model := models.NewBetterRNN(hiddenSize, 1)
optimizer := optimizers.NewAdam().Setup(model)
for i := 0; i < maxEpoch; i++ {
model.ResetState()
loss, count := dz.NewVariable(core.New1D(0)), 0
for dataloader.Next() {
x, t := dataloader.Read()
y := model.Apply(x).First()
lossVar := fn.MeanSquaredError(y, t)
loss = fn.Add(loss, lossVar)
if count += 1; (count%bpttLength) == 0 || count == seqlen {
model.ClearGrads()
loss.Backward()
loss.UnchainBackward()
optimizer.Update()
}
}
avgLoss := loss.Data().At(0, 0) / float64(count)
fmt.Println("epoch", i+1, " | loss", avgLoss)
}
今回はバッチサイズを利用して効率よく学習を行いました。
基本的にRNNの時とそこまで差があるようには見えませんね😅
結果 🟦: 予測線 / 🟩: Sin波
RNNの時よりも予測が近いですね!これで、僕も未来を見据える
ソースコード
まとめ✅
今回の記事は歴代トップに内容が薄いですね。。。。。🙇♂️
これでDeZeroのGonum実装は終わりです!
CNNを飛ばしてしまったので、近々CNN関係の記事を書こうと思います。
参考書を写経しているだけなのになかなか辛かったです。。。。😓
ゼロから作るDeep Learningの1・3を読めばもう楽しくてやばいですね。
PS
「ゼロから作るDeep Learning ❹ ―強化学習編」が出るので楽しみですね!
まずはPythonでしっかり理解してから別言語で実装してみたを作りたいと思います笑