DeepLabデモのコードを理解していく
前回はDeepLabデモを実行したので、あらためてデモの実行内容を理解していく。
対象ファイルは以下のとおり
ensorflow/models/research/deeplab/deeplab_demo.ipynb
セル1
タイトル(DeepLab Demo)
このデモは、サンプル入力イメージでDeepLabのセマンティックセグメンテーションモデルを実行する手順を実演するよー。
セル2
モジュールのインポート。
#@title Imports
import os
from io import BytesIO
import tarfile
import tempfile
from six.moves import urllib
from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
L1: タイトル
L3: パス操作に使用するOSインタフェースのモジュール
L4: 画像読込に使用するバイナリI/Oモジュール(I/Oモジュール所属)
L5: TARアーカイブのモジュール
L6: 一時ファイル・ディレクトリ作成のモジュール
L7: URLリクエストに使用するモジュール(six.movesはPython2との互換性ライブラリ)
L9: 複数グラフレイアウトモジュール(matplotlibはデータ可視化ライブラリ)
L10: 結果表示に使用するグラフ描画モジュールをpltとしてインポート
L11: 数値計算ライブラリをnpとしてインポート
L12: 画像処理ライブラリ
L14: TensorFlowライブラリをtfとしてインポート
セル3
ヘルパーメソッド
#@title Helper methods
class DeepLabModel(object):
"""Class to load deeplab model and run inference."""
INPUT_TENSOR_NAME = 'ImageTensor:0'
OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
INPUT_SIZE = 513
FROZEN_GRAPH_NAME = 'frozen_inference_graph'
L1: タイトル(ヘルパーメソッド)
L4: DeepLabModelクラスの宣言
L5: コメント行(DeepLabモデルを読み込みとインタフェース実行のクラス)
L6 入力テンソル名INPUT_TENSOR_NAME
L7: 出力テンソル名OUTPUT_TENSOR_NAME
L9: 入力画像のサイズ(高さ・画像いずれか)
L10: 読み込むモデルのファイル名
def __init__(self, tarball_path):
"""Creates and loads pretrained deeplab model."""
self.graph = tf.Graph()
graph_def = None
# Extract frozen graph from tar archive.
tar_file = tarfile.open(tarball_path)
for tar_info in tar_file.getmembers():
if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
file_handle = tar_file.extractfile(tar_info)
graph_def = tf.GraphDef.FromString(file_handle.read())
break
tar_file.close()
if graph_def is None:
raise RuntimeError('Cannot find inference graph in tar archive.')
with self.graph.as_default():
tf.import_graph_def(graph_def, name='')
self.sess = tf.Session(graph=self.graph)
L12: init()メソッドの定義
L13: コメント行(学習済DeepLabモデルを読み込む)
L14: Graphのインスタンス作成
L16: graph_defの初期化
L17: コメント行(TARアーカイブから圧縮したグラフを解凍)
L18: メソッドの引数として与えられたパスのTARファイルを開く
L19: TARアーカイブのメンバー毎に以L21〜24を繰り返し
L20: ファイル名に定数FROZEN_GRAPH_NAMEが含まれていればL21
L21: TARアーカイブから該当メンバをファイルオブジェクトとして抽出
L22: ファイルオブジェクトからモデルを読み込んでgraph_defに設定
L23: 繰り返し終了
L25: TARアーカイブを閉じる
L27-28: graph_defが設定されなかった場合、ランタイムエラーを発生させる
L30: Graphインスタンスにデフォルトグラフを設定して以下の処理
L31: デフォルトグラフにgraph_defをインポート
L33: Graphインスタンスでセッションを作成
def run(self, image):
"""Runs inference on a single image.
Args:
image: A PIL.Image object, raw input image.
Returns:
resized_image: RGB image resized from original input image.
seg_map: Segmentation map of `resized_image`.
"""
width, height = image.size
resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
target_size = (int(resize_ratio * width), int(resize_ratio * height))
resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
batch_seg_map = self.sess.run(
self.OUTPUT_TENSOR_NAME,
feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
seg_map = batch_seg_map[0]
return resized_image, seg_map
L35: run()メソッドの定義
L36-44: コメント行 単一画像のインタフェースを実行
引数: image: PIL画像オブジェクト、RAW入力画像
戻り値:resized_image: リサイズしたRGB画像
seg_map: リサイズ画像の分類マップ
L45: 幅、高さを引数から取得
L46: 幅、高さの大きい方からリサイズ比率を計算
L47: リサイズ後の幅、高さを計算
L48: オリジナル画像をリサイズ
L49-51: セッションを実行
ドキュメントには載ってないがfeed_dictのキーに文字列を与えれる?
L52: セッション実行結果からマップを取得
L53: 戻り値を返却
def create_pascal_label_colormap():
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
Returns:
A Colormap for visualizing segmentation results.
"""
colormap = np.zeros((256, 3), dtype=int)
ind = np.arange(256, dtype=int)
for shift in reversed(range(8)):
for channel in range(3):
colormap[:, channel] |= ((ind >> channel) & 1) << shift
ind >>= 3
return colormap
L56: create_pascal_label_colormap()メソッドの定義
L57-61: コメント行 PASCAL VOC segmentation benchmarkで使用されるラベルカラーマップの作成
戻り値:分類結果を視覚化するためのカラーマップ
L62: カラーマップとして256×3のint配列を初期化
L63: indに[0, 1, 2, ... 253, 254, 255]のint配列を設定
L65: shift=7, 6, 5, 4, 3, 2, 1, 0の順でL66-68を繰り返し
L66: channel=0, 1, 2の順でL67を繰り返し
L67: colormap[all][channel]に、以下との論理和を設定
( (indをchannnelビット分右シフト)と1の論理積) をshiftビット分左シフト
L68: indを3ビット分右シフト
L69: coloemapを返却
※カラーマップは重複しない256色の配列となる。(RGBが同値にならないようビットシフトがずらしてある)
def label_to_color_image(label):
"""Adds color defined by the dataset colormap to the label.
Args:
label: A 2D array with integer type, storing the segmentation label.
Returns:
result: A 2D array with floating type. The element of the array
is the color indexed by the corresponding element in the input label
to the PASCAL color map.
Raises:
ValueError: If label is not of rank 2 or its value is larger than color
map maximum entry.
"""
if label.ndim != 2:
raise ValueError('Expect 2-D input label')
colormap = create_pascal_label_colormap()
if np.max(label) >= len(colormap):
raise ValueError('label value too large.')
return colormap[label]
L72: label_to_color_image()メソッドの定義
L73-86: コメント行 データセットカラーマップによって定義された色をラベル加える
引数: label: 分類ラベルを収納した、整数形式の二次元配列
戻り値: 浮動小数店形式の二次元配列。配列の要素は、PASCALカラーラベルへの入力の対応する要素のカラーインデックスになる。
例外: ValueError: ラベルがランク2ではない、または値がカラーマップの最大エントリー数より大きい場合
L87-88: ラベルの次元数が2でない場合、 ValueError
L90: create_pascal_label_colormap()メソッドを実行してカラーマップ取得
L92-93: ラベルの最大値がカラーマップの配列長より大きい場合、ValueError
L95: 分類マップから変換したカラーマップを返却
def vis_segmentation(image, seg_map):
"""Visualizes input image, segmentation map and overlay view."""
plt.figure(figsize=(15, 5))
grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
plt.subplot(grid_spec[0])
plt.imshow(image)
plt.axis('off')
plt.title('input image')
plt.subplot(grid_spec[1])
seg_image = label_to_color_image(seg_map).astype(np.uint8)
plt.imshow(seg_image)
plt.axis('off')
plt.title('segmentation map')
plt.subplot(grid_spec[2])
plt.imshow(image)
plt.imshow(seg_image, alpha=0.7)
plt.axis('off')
plt.title('segmentation overlay')
unique_labels = np.unique(seg_map)
ax = plt.subplot(grid_spec[3])
plt.imshow(
FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
ax.yaxis.tick_right()
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
plt.xticks([], [])
ax.tick_params(width=0.0)
plt.grid('off')
plt.show()
L98: vis_segmentation()メソッドの定義
L99: コメント行(入力画像、分類マップ、オーバーレイの視覚化)
L100: 横15インチ、縦5インチでFigureインスタンスを作成
L101: 縦1個、横4個(比率 6:6:6:1)でグリッド作成
L103: グリッド[0]にプロット開始
L104: グリッド[0]に画像を表示
L105: グリッド[0]の軸を非表示
L106: グリッド[0]にタイトルを設定(input image)
L108: グリッド[1]にプロット開始
L109: 分類マップをカラーマップ化
L110: グリッド[1]にカラーマップを表示
L111: グリッド[1]の軸を非表示
L112: グリッド[1]にタイトルを設定(segmentation map)
L114: グリッド[2]にプロット開始
L115: グリッド[2]に画像を表示
L116: グリッド[2]に透過率0.7でカラーマップを表示
L117: グリッド[2]の軸を非表示
L118: グリッド[2]にタイトルを設定(segmentation overlay)
L120: 分類マップの重複を削除した一次元配列をunique_labelsに設定
L121: グリッド[3]のプロット開始
L122: グリッド[3]にunique_labelsで使用しているカラーマップを表示
L123: グリッド[3]のY軸の目盛り、ラベルを軸の右に移動
L124: グリッド[3]のY軸にunique_labelsで使用しているラベル名を表示
L125: グリッド[3]のX軸に何も表示しない
L126: グリッド[3]の目盛りの横幅を0.0に設定
L127: グリッドの区切りを非表示
L128: グリッドを表示実行
LABEL_NAMES = np.asarray([
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
])
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
L131-135: ラベル名一覧を定義
L136: ラベル名から等差数列を作成
L137: 等差数列からカラーマップを作成
セル4
DeepLabのモデルをインターネットからDLして読み込む。
#@title Select and download models {display-mode: "form"}
MODEL_NAME = 'mobilenetv2_coco_voctrainaug' # @param ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {
'mobilenetv2_coco_voctrainaug':
'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
'mobilenetv2_coco_voctrainval':
'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
'xception_coco_voctrainaug':
'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
'xception_coco_voctrainval':
'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}
_TARBALL_NAME = 'deeplab_model.tar.gz'
model_dir = tempfile.mkdtemp()
tf.gfile.MakeDirs(model_dir)
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME],
download_path)
print('download completed! loading DeepLab model...')
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')
L1: タイトル(モデルの選択とダウンロード)
L3: 実際に使用するモデル名を定義(後ろの配列から1つを記述)
L5: DLするURLの固定部分を定義
L6-15: モデル名とDL先tar.gzファイル名の辞書を定義
L16: DL後のtar.gzファイル名を定義
L18: 一時ディレクトリの作成
L19: L18で取得したディレクトリ名で再度ディレクトリ作成
L21: 一時ディレクトリとtar.gzファイル名を結合してDL先パスを作成
L22: DL開始をプリント
L23-24: DL実行
L25: DL完了、モデル読込開始をプリント
L27: DLしたファイルを与えてセル3で定義したDeepLabModelのインスタンスを作成
L28: モデル読込成功をプリント
セル5
サンプル画像による実行の説明
セル6のIMAGE_URLを空のままSAMPLE_IMAGEをどれか設定するか、IMAGE_URLに任意の画像URLを設定してね。
注意!:このデモは高速化のために単一スケール推論を実行しているよ。だから結果は、複数スケール推論と左右反転入力を使用するREADMEのビジュラライゼーションと少し違ってるよ。
セル6
セル4で読み込んだモデルを使用して、サンプル画像をセマンティックセグメンテーションする。
#@title Run on sample images {display-mode: "form"}
SAMPLE_IMAGE = 'image1' # @param ['image1', 'image2', 'image3']
IMAGE_URL = '' #@param {type:"string"}
_SAMPLE_URL = ('https://github.com/tensorflow/models/blob/master/research/'
'deeplab/g3doc/img/%s.jpg?raw=true')
L1: タイトル(サンプル画像による実行)
L3: サンプル画像名の定義(後ろの配列から1つを記述)
L4: 画像URL(任意の画像URLを記述可能)
L6-7: サンプル画像URLの定義
def run_visualization(url):
"""Inferences DeepLab model and visualizes result."""
try:
f = urllib.request.urlopen(url)
jpeg_str = f.read()
original_im = Image.open(BytesIO(jpeg_str))
except IOError:
print('Cannot retrieve image. Please check url: ' + url)
return
print('running deeplab on image %s...' % url)
resized_im, seg_map = MODEL.run(original_im)
vis_segmentation(resized_im, seg_map)
L10: run_visualization()メソッドの定義
L11: コメント行(DeepLabモデルの推論と結果の視覚化)
L13: 引数で与えられた画像URLを開いて、HTTPResponseオブジェクトを取得
L14: HTTPResponseオブジェクトをバイナリデータで読み込む
L15: バイナリデータを画像に変換
L16-18: IOError時、エラーをプリントしてメソッド終了
L20: モデル実行中をプリント
L21: 画像を引数にして、セル3で定義したrun()メソッドを実行
L23: run()メソッドの結果を引数にして、vis_segmentation()メソッドを実行
image_url = IMAGE_URL or _SAMPLE_URL % SAMPLE_IMAGE
run_visualization(image_url)
main処理
L26: 画像URL作成(IMAGE_URLが空の場合、SAMPLE_URL中の%sをSAMPLE_IMAGEに変換して使用)
L27: 画像URLを引数にして、定義したrun_visualization()メソッドを実行
感想
・自作モデルを使用したい場合は5セル目のDL周りを変えればいいっぽい。
・自作モデルのモデル名、テンソル名、画像サイズ、ラベル名は必要に合わせて定義と合わせる。
・推論のスケールと左右反転入力の意味がわかってないので要調査。