見出し画像

【Go・Deep Learning その④】 ゼロから作るdeep learning 3を復習 ーRNNとLSTM

これまで

今回

CNNは多次元配列がないと面倒くさそうだったため、飛ばしました。🙇‍♂️
なので、CNN関係+ラズパイあたりで近々記事を出そうと思います。🙇‍♂️
許してください。

今回はRNN・LSTMです。
時系列データとかに用いられるイメージですかね?(株とか?)

RNN ⏰

簡単にいうと隠れ層が繰り返し状態を使う感じですかね。。
なにはともわれドン(ソースコード👇)

ソースコード


const (
	maxEpoch   = 100
	hiddenSize = 100
	bpttLength = 30
)

trainSet := datasets.NewSinCurve(dz.Train(true))
seqlen := trainSet.Len()

model := models.NewSimpleRNN(hiddenSize, 1)
optimizer := optimizers.NewAdam().Setup(model)

for i := 0; i < maxEpoch; i++ {
	model.ResetState()

	loss, count := dz.NewVariable(core.New1D(0)), 0
	for j := 0; j < trainSet.Len(); j++ {
		xData, tData := trainSet.Get(j)
		x := dz.AsVariable(xData)
		t := dz.AsVariable(tData)
		y := model.Apply(x).First()
		lossVar := fn.MeanSquaredError(y, t)
		loss = fn.Add(loss, lossVar)

		if count += 1; (count%bpttLength) == 0 || count == seqlen {
			model.ClearGrads()
			loss.Backward()
			loss.UnchainBackward()
			optimizer.Update()
		}
	}
	avgLoss := loss.Data().At(0, 0) / float64(count)
	fmt.Println("epoch", i+1, " | loss", avgLoss)
}

簡単に説明しますと、
Sin波を学習しているのですが、
参考に記事でも貼っておきますね💔

https://qiita.com/kazukiii/items/df809d6cd5d7d1f57be3



結果 🟦: 予測線 /  🟩: Sin波

ある程度が予測できていそう。。。。

LSTM 💪

こちらは簡単に言ってしまえばRNNの進化版みたいなイメージですかね?
違いは中間層を設けるかって感じですかね?     参考<-わかりやすい
何はともわれ実装をドン

ソースコード

const (
	maxEpoch   = 100
	hiddenSize = 100
	batchSize  = 30
	bpttLength = 30
)

trainSet := datasets.NewSinCurve(dz.Train(true))
dataloader := loader.NewSeqDataLoader(trainSet, batchSize)
seqlen := trainSet.Len()

model := models.NewBetterRNN(hiddenSize, 1)
optimizer := optimizers.NewAdam().Setup(model)

for i := 0; i < maxEpoch; i++ {
	model.ResetState()

	loss, count := dz.NewVariable(core.New1D(0)), 0
	for dataloader.Next() {
		x, t := dataloader.Read()
		y := model.Apply(x).First()
		lossVar := fn.MeanSquaredError(y, t)
		loss = fn.Add(loss, lossVar)
		if count += 1; (count%bpttLength) == 0 || count == seqlen {
			model.ClearGrads()
			loss.Backward()
			loss.UnchainBackward()
			optimizer.Update()
		}
	}
	avgLoss := loss.Data().At(0, 0) / float64(count)
	fmt.Println("epoch", i+1, " | loss", avgLoss)
}

今回はバッチサイズを利用して効率よく学習を行いました。
基本的にRNNの時とそこまで差があるようには見えませんね😅

結果 🟦: 予測線 /  🟩: Sin波

RNNの時よりも予測が近いですね!
これで、僕も未来を見据える

ソースコード


まとめ✅

今回の記事は歴代トップに内容が薄いですね。。。。。🙇‍♂️
これでDeZeroのGonum実装は終わりです!

CNNを飛ばしてしまったので、近々CNN関係の記事を書こうと思います。
参考書を写経しているだけなのになかなか辛かったです。。。。😓

ゼロから作るDeep Learningの1・3を読めばもう楽しくてやばいですね。

PS

「ゼロから作るDeep Learning ❹ ―強化学習編」が出るので楽しみですね!
まずはPythonでしっかり理解してから別言語で実装してみたを作りたいと思います笑


いいなと思ったら応援しよう!