scikit-learnとSVMで画像文字認識

どうもおはこんばんにちは。マタキチです。
近所に花屋さんで面白いサボテンを発見しました。なんでも電磁波サボテンという異名を持つサボテン。その名もセレウス。いや神様かっつって。神々しい名前かって。こちとら大和の民だぞって。なんでもこちらのサボテン、NASAの研究では電磁波を食すとかなんとか。サボテンは好きなので興味津々なのですが、これはいったいどんなサボテンなんでしょうか。どなたか教えていただけたら嬉しいです。そんな今日この頃。

pythonでスクレイピングやら機械学習を勉強しているので記述していきます。間違いありましたらご指摘いただけたら嬉しいです。

手書き数字の認識

今回は簡単な画像認識をしてみましょう。機械学習のメジャーライブラリ「scikit-learn」のサンプルデータにはあらかじめ手書き数字のサンプルが用意されているので手軽にサンプル実行できる準備が整っているのです。

手書き数字データの入手

いろいろ調べてみましたが、MNISTが公開しているデータを利用することにします。練習用に6万、テスト用に1万のデータがあるのでこれにしてみました。サイトはこちら。

サイト:THE MNIST DATABASE of handwritten digits
 URL:http://yann.lecun.com/exdb/mnist/

サイトから入手できるデータはgzip形式で圧縮されています。ではダウンロードからgzip解凍してみましょう。

▽ファイル名:download.py

import urllib.request as req
import gzip, os, os.path

# 保存するディレクトリとURLを指定
savepath = "./mnist"
baseurl = "http://yann.lecun.com/exdb/mnist/"
files = ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]

# ダウンロード開始
if not os.path.exists(savepath): os.mkdir(savepath)
for f in files:
   url = baseurl + "/" + f
   loc = savepath + "/" + f
   print("download :", url)
   if not os.path.exists(loc): req.urlretrieve(url, loc)

# gzip形式を解凍
for f in files:
   gz_files = savepath + "/" + f
   raw_file = savepath + "/" + f.replace(".gz", "")
   print("gzip:", f)
   with gzip.open(gz_files, "rb") as fp:
       body = fp.read()
       with open(raw_file, "wb") as w:
           w.write(body)

# 成功したらなんか表示させる
print("OK!")

実行結果はこちら。デスクトップにmnistファイルと解凍後のファイルのダウンロードが完了しています。

スクリーンショット 2019-10-14 21.30.40

画像データの中身ですが、0がは背景で1-255が黒です。ちなみに数字が大きくなるほど濃い黒であることを表しています。また、バイナリーデータなのでcsvデータに変換してみましょう。

▽ファイル名:tocsv.py
import struct


def to_csv(name, maxdata):

   #ラベルファイルとイメージファイルを開く
   lbl = open("./mnist/"+name+"-labels-idx1-ubyte", "rb")
   img = open("./mnist/"+name+"-images-idx3-ubyte", "rb")
   csv = open("./mnist/"+name+".csv", "w", encoding="utf-8")

   #ヘッダー情報を読む
   mag, lbl_count = struct.unpack(">II", lbl.read(8))
   mag, img_count = struct.unpack(">II", img.read(8))
   rows, cols = struct.unpack(">II", img.read(8))
   pixels = rows * cols

   #画像データを読んでCSVで保存します
   res = []
   for idx in range(lbl_count):
       if idx > maxdata: break
       label = struct.unpack("B", lbl.read(1))[0]
       bdata = img.read(pixels)
       sdata = list(map(lambda n: str(n), bdata))
       csv.write(str(label)+",")
       csv.write(",".join(sdata)+"\r\n")

       #うまく取り出せたかどうかPGMで保存してテストします
       if idx < 10:
           s = "P2 28 28 255\n"
           s += " ".join(sdata)
           iname = "./mnist/{0}-{1}-{2}.pgm".format(name, idx, label)
           with open(iname, "w", encoding="utf-8") as f:
               f.write(s)
   csv.close()
   lbl.close()
   img.close()

#学習用データ5000件
to_csv("train", 5000)
#テスト用データ2000件
to_csv("t10k", 2000)

実行すると先ほどmnistディレクトリにCSVファイルが2つ出力します。t10k.csvとtrain.csvというデータです。open()で開く時、バイナリーファイルを開くことを表す"br"をつけます。また、任意のバイトを読み込んで整数として読むにはstruct.unpack()を利用します。

画像データを学習させよう

ここまでで準備で得られた画像データを用いて機械学習を実践していきましょう。

▽ファイル名:mnist-train.py
from sklearn import svm, metrics

# CSVファイルを読むんで学習用データに整形
def load_csv(fname):
   label = []
   images = []
   with open(fname, "r") as f:
       for line in f:
           cols = line.split(",")
           if len(cols) < 2: continue
           label.append(int(cols.pop(0)))
           vals = list(map(lambda n: int(n) / 256, cols))
           images.append(vals)
   return {"labels":label, "images":images}

data = load_csv("./mnist/train.csv")
test = load_csv("./mnist/t10k.csv")

# 学習
clf = svm.SVC()
clf.fit(data["images"], data["labels"])

# 予測
predict = clf.predict(test["images"])

# 結果がどの程度合っていたか認識
ac = metrics.accuracy_score(test["labels"], predict)
cl = metrics.classification_report(test["labels"], predict)
print("正解率=", ac)
print("レポート:")
print(cl)

実行結果がこちらです。

スクリーンショット 2019-10-14 23.44.57

あれまっ。結果が正解率が0.12くらいとあまりよろしくないです。これはデータの件数が足りないのか、それともデータの質がよくないのか、CSVの作り方を間違えてるのか(バイナリデータなので区切りかたを間違えているのか)。うーん。これは次回にじっくり考えたいと思います。

-追記-
画像データを確認したところ、まっくろくろすけでした。データの切り出し方を間違えてる模様。誰かヘルプミー、ビリーブミーです。
-追記2-
原因がわかりました。ヘッダー情報を読み混んだ際に、画像データをpixelsとしていますが、横ピクセルと縦ピクセルを掛け算(row * cols)しておりませんでした。コードを書き直して実行した結果が下記になります。

スクリーンショット 2019-10-15 20.39.53

正答率が約0.79と大幅に向上しましたが・・・いい結果とは言い難いのではないでしょうか。解決できたのでよかったよかった。

まとめ

今回の記事では結果はよくない結果になりました。でも正答率が出せるプログラミングは組めたので一応達成なのかな?と思いたいです。なんでかはじっくりこの後考察して、今後この記事に追記していきたいと思います。以上ですーー。

いいなと思ったら応援しよう!