【Pythonで生きる】MNISTを使っての画像認識の機械学習アプリを作ってみた
今回のゴールはとりあえず動くモノを作る、です。
細かい部分はチームでやったので、僕にも理解不能な部分もありますが、一旦動くモノが出来た&コードとしても良いものになったので、記事にしました。
こちらの完成品のコードは、再利用可能です。
購入して頂いた方は自由に解読・改変して頂いて構いません。
背景
ニートがPython を使って、日々生きていく中で毎日のルーティン作業を自動化することができることに気付いてしまいました。例えば、競馬の自動予測アプリを作って、日々の収入の足しにすることなどが出来ます。
巷では、仮想通貨で稼いでいる方もたくさんいらっしゃいますが、僕はITを駆使しながら、ニートライフを送っております。
競馬の自動予測アプリなどは、世間にあまり出ておらず、課金して使うモノなどは多数出ておりますが、僕は自分で作って学習コストはかかるものの、運用コスト0で収入を得たりしようとしています。
競馬の自動予測ツールは有料課金で販売されていてます。1ヶ月8000円のものもあれば、さらに上を行くモノもたくさんあります。FXも例外ではありませんし、現在その開発も行っております。
今回作るのは、機械学習を始めるに当たり一番取っつきやすい機械学習のアプリケーションの作り方を記したいと思います。
機械学習が使えるエンジニアの給料が上がっている
あくまで参考ですが、
このような記事を見ると、年収1500万円以上でもいいエンジニアなら獲得したい!のような記事もありますし、
どこかで見たような、企業名もチラホラとある中での年収の高さも気になります。
今回は機械学習の最初の一歩ということで、MNISTというブラウザで数字を描画して、その数字を認識してもらうというものを作ってみました。
目次
・MNISTで手書きの数字を認識できるアプリケーションを作ろう
・必要な知識
・何が出来るようになるか
・フローの確認
・MNISTの元データを引っ張ってこよう
・いよいよ機械学習のコードを実装
・実行して結果の確認
・ブラウザで描画してみよう
MNISTで手書きの数字を認識できるアプリケーションを作ろう
今回の完成形のイメージとしては、NN(Neural network)を使ったシンプルな機械学習のアプリケーションの開発です。ある程度の機械学習なら自分で実装できるようになるレベルを目指す+動くモノを作れるようになるのがゴールです。
必要な知識
・Pythonの基礎知識
・コマンドプロント・ターミナルの基本的な使い方
・PythonとAnacondaの環境構築
Pythonの基礎知識は、『みんなのPython第4版』がオススメです。
僕個人的には非常にわかりやすく、いつも持ち歩いています。というぐらいに気に入っています。
なにが出来るようになるか?
機械学習アプリのエントリーレベルなので、機械学習で使うライブラリーであったり、計算式の組み方、ロジック面の理解、そしてフロントエンドとバックエンドをどう繋ぐか?のような一連の流れは理解できるようになるかと思います。
具体的に書くと
・Pythonを使っての機械学習の基礎インプット
・Pythonを使って重みをアウトプットし、フロントとの連携のスクリプト作成方法
フローの確認(前提)
ブラウザで[0]から[9]の間での数字をマウスを使って描画して、自分が描画した数字は何番かを判定してもらいます。
フローの確認(実際)
まず、フロントエンド側はCanvasタグを使って、JavaScriptでの描画の実装をします。
次に、NeuralNetの実装を行います。
NeuralNetの実装はMNISTの画像データを読み込み学習させます。
MNISTとは手書き数字画像60000枚とテスト画像10000枚を集めた画像データセットになります。
次に学習データの重みを、XXX.datというファイル形式で出力して、ブラウザ側でロードします。
そうすると、学習データを元に判定していきます。
新しいインプットがあるたびに、学習能力も向上する仕組みになっています。
MNISTの元データを引っ張ってこよう
まずは、mnist.pyというファイルを作りましょう。
中身です。
# coding: utf-8
try:
import urllib.request
except ImportError:
raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np
url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784
def _download(file_name):
file_path = dataset_dir + "/" + file_name
if os.path.exists(file_path):
return
print("Downloading " + file_name + " ... ")
urllib.request.urlretrieve(url_base + file_name, file_path)
print("Done")
def download_mnist():
for v in key_file.values():
_download(v)
def _load_label(file_name):
file_path = dataset_dir + "/" + file_name
print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8)
print("Done")
return labels
def _load_img(file_name):
file_path = dataset_dir + "/" + file_name
print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, img_size)
print("Done")
return data
def _convert_numpy():
dataset = {}
dataset['train_img'] = _load_img(key_file['train_img'])
dataset['train_label'] = _load_label(key_file['train_label'])
dataset['test_img'] = _load_img(key_file['test_img'])
dataset['test_label'] = _load_label(key_file['test_label'])
return dataset
def init_mnist():
download_mnist()
dataset = _convert_numpy()
print("Creating pickle file ...")
with open(save_file, 'wb') as f:
pickle.dump(dataset, f, -1)
print("Done!")
def _change_one_hot_label(X):
T = np.zeros((X.size, 10))
for idx, row in enumerate(T):
row[X[idx]] = 1
return T
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""MNISTデータセットの読み込み
Parameters
----------
normalize : 画像のピクセル値を0.0~1.0に正規化する
one_hot_label :
one_hot_labelがTrueの場合、ラベルはone-hot配列として返す
one-hot配列とは、たとえば[0,0,1,0,0,0,0,0,0,0]のような配列
flatten : 画像を一次元配列に平にするかどうか
Returns
-------
(訓練画像, 訓練ラベル), (テスト画像, テストラベル)
"""
if not os.path.exists(save_file):
init_mnist()
with open(save_file, 'rb') as f:
dataset = pickle.load(f)
if normalize:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0
if one_hot_label:
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
if __name__ == '__main__':
init_mnist()
これがいわゆるmnistの元からデータを引っ張ってくるものになります。
次にHTMLのコーディングを行います。
こちらは、ブラウザで描画してもらうためのインターフェイスになります。
<!DOCTYPE html>
<html>
<head>
<title>MNIST Test</title>
</head>
<body>
<H1> MNIST Check </H1>
<div>
<input type="file" id="file" accept=".dat">
</div>
<script type="text/javascript" src="draw.js"></script>
<script type="text/javascript" src="model.js"></script>
<script type="text/javascript" src="judge.js"></script>
<canvas id='canvas' width=448 height=448 style="background-color:black;"></canvas>
<canvas id='canvas2' width=448 height=448 style="background-color:black;"></canvas>
<div>
<button type="button" onclick="check()" id ="checkButton" class="button" disabled>Check</button>
<button type="button" onclick="allClear()" class="button">Clear</button>
</div>
<div id="result" class="result">
</div>
</body>
</html>
input.htmlとでもしておきましょう。
そして、それぞれ読み込んでいるjsファイルを記述していきます。
prevX=0;
prevY=0;
mouseDown=false;
function drawSetup(canvas, canvas2){
canvas.onmousedown = function(e){
var r = canvas.getBoundingClientRect();
prevX=e.clientX - r.left;
prevY=e.clientY - r.top;
mouseDown=true;
}
canvas.onmousemove = function(e){
if(mouseDown){
var r = canvas.getBoundingClientRect();
x =e.clientX - r.left;
y =e.clientY - r.top;
draw(x,y, canvas);
}
}
canvas.onmouseup =function(e){
mouseDown=false;
copy(canvas, canvas2);
}
}
function draw(x,y, canvas){
var context=canvas.getContext('2d');
context.strokeStyle="white";
var w = 40;
context.lineWidth=w;
context.lineCap="round";
context.lineJoin="round";
context.beginPath();
context.moveTo(prevX,prevY);
context.lineTo(x,y);
context.closePath();
context.stroke();
prevX=x;
prevY=y;
}
function copy(canvas, canvas2){
var h = canvas.height;
var w = canvas.width;
img = canvas.getContext('2d').getImageData(0,0,h,w);
data = img.data
for(var i=0;i<28;i++){
for(var j=0;j<28;j++){
var sum = 0;
for(var k=0;k<16;k++){
for(var l=0;l<16;l++){
x = i*16+k;
y = j*16+l;
var s = x+y*16*28;
if (data[s*4]>128){
sum++;
}
}
}
for(var k=0;k<16;k++){
for(var l=0;l<16;l++){
x = i*16+k;
y = j*16+l;
var s = x+y*16*28;
data[s*4] = sum;
data[s*4+1] = sum;
data[s*4+2] = sum;
data[s*4+3] = 255;
}
}
}
}
canvas2.getContext('2d').putImageData(img,0,0);
}
function getX(canvas){
var h = canvas.height;
var w = canvas.width;
img = canvas.getContext('2d').getImageData(0,0,h,w);
var x = new Float32Array(28*28);
data = img.data
for(var i=0;i<28;i++){
for(var j=0;j<28;j++){
var sum = 0;
for(var k=0;k<16;k++){
for(var l=0;l<16;l++){
sx = i*16+k;
sy = j*16+l;
var s = sx+sy*16*28;
if (data[s*4]>128){
sum++;
}
}
}
x[i+j*28] = sum/256.0;
}
}
//console.log(x);
return x;
}
function canvasClear(canvas){
var context=canvas.getContext('2d');
context.fillStyle="black";
context.fillRect(0,0,canvas.width,canvas.height);
}
こちらはdraw.jsとしましょう。
window.onload = function(){
drawSetup($('canvas'),$('canvas2'));
}
var model;
function onFileSelect(e) { var f = e.target.files;
var reader = new FileReader();
reader.onload = function(filename){
var fs = new Float32Stream(reader.result);
model = new Model(fs);
$('checkButton').disabled = "";
}
reader.readAsArrayBuffer(f[0]);
}
$('file').addEventListener('change', onFileSelect, false);
function check() {
var x = getX($('canvas'));
var i = model.recognize(x);
$('result').innerHTML = "Your figure is " + i + "!";
}
function allClear(){
$('result').innerHTML = "";
canvasClear($('canvas'));
canvasClear($('canvas2'));
}
function $(id){
return document.getElementById(id);
}
こちらは、judge.jsです。
そして、最後にmodel.jsです。
Link = function(n_in, n_out, fs) {
this.n_in = n_in;
this.n_out = n_out;
this.W = fs.a.slice(fs.index,fs.index+n_in*n_out);
var NW = new Float32Array(n_in*n_out);
for(var i = 0; i < this.n_in; i++){
for(var j = 0; j < this.n_out; j++){
NW[n_in*j+i] = this.W[i*n_out+j];
}
}
for(var k = 0; k < n_in*n_out; k++){
this.W[k] = NW[k];
}
console.log(fs.index);
//console.log(this.W[0]);
fs.index += this.W.length;
this.b = fs.a.slice(fs.index,fs.index + n_out);
fs.index += this.b.length;
}
Link.prototype.hello = function(){
console.log(this.W[0]);
console.log(this.b[0]);
}
Link.prototype.get = function(x){
var y = new Float32Array(this.n_out);
y.fill(0.0);
for(var i = 0; i < this.n_out; i++){
y[i] = 0.0;
for(var j = 0; j < this.n_in; j++){
y[i] += this.W[i*this.n_in + j]*x[j];
}
y[i] += this.b[i];
}
return y;
};
Model = function(fs) {
this.n_in = 28 * 28;
//this.n_units = 28 * 28;
this.n_units_hidden = 200;
this.n_out = 10;
this.l1 = new Link(this.n_in, this.n_units_hidden, fs);
this.l2 = new Link(this.n_units_hidden, this.n_out, fs);
// console.log(this.l2.W);
};
Model.prototype.predict = function(x){
y1 = this.l1.get(x);
z1 = this.sigmoid(y1);
y2 = this.l2.get(z1);
z2 = this.softmax(y2);
console.log('確率可視化' + z2);
return z2;
};
Model.prototype.sigmoid = function(x){
var y = new Float32Array(x.length);
for(var i = 0; i < x.length; i++){
y[i] = 1.0 / (1.0 + Math.exp(-x[i]));
}
return y;
};
Model.prototype.softmax = function(x){
var y = new Float32Array(x.length);
var sum;
sum = 0.0;
for(var i = 0; i < x.length; i++){
sum += Math.exp(x[i]);
}
for(var i = 0; i < x.length; i++){
y[i] = Math.exp(x[i])/sum;
}
return y;
};
Model.prototype.recognize = function(x){
var y = this.predict(x);
return y.indexOf(Math.max.apply(null,y));
}
Float32Stream = function(result){
this.a = new Float32Array(result);
this.index = 0;
};
そして、このmodel.jsがかなりのキーになっています。