見出し画像

馬をシマウマに、シマウマを馬にするCycle GAN のおはなし(論文解説)

 こんにちは、こんばんは、teftefです。NovelAI が登場したのでここ最近の記事は作品、Prompt 紹介をでいっぱいになっていました。少し前に敵対的生成ネットワーク、通称 GAN について基本的なところについて書きました。記事を載せておきます。今回はいくつかある GAN うちの中でもCycle-GANについて書いていきたいと思います。前半では高校生レベルの数学の知識があれば理解できる内容、後半で少し専門的なことを書こうと思います。基本的には丁寧に説明するので、流れだけ知りたいという方も是非読んでいってもらえると嬉しいです!
 それでは行きます。

論文

 今回この記事で元にした論文はこの 2 件です
https://arxiv.org/abs/2103.03467 (↑)
https://arxiv.org/abs/1703.10593 (↓)

上の論文は CycleGAN を高速化する手法と馬をシマウマにする方法であり、下の論文は CycleGAN の基本原理や手法について書いてあります。それではまずCycleGANは何をやっているかから見ていきましょう。

まずは用語から

教師あり学習

 まずは『教師あり学習』という言葉から。これは機械学習の手法の一種であり、AI にデータとその正解情報を渡して学習させます。例えば 0~9 の手書き数字の画像から何が書かれているかを判別する AI があります (mnist) 。これは機械学習を始めた時にチュートリアルで使われることが多いのですが、画像に 0~9 の画像に、それぞれの画像がどの数と対応しているかを示す "ラベル" が貼ってあります。これで AI が正しいラベルの数値を出すまで学習させます。これが教師あり学習です。

ラベル付けの例

教師なし学習

 対して、『教師なし学習』というのは機械学習の手法の一種であり、データに正解を与えない状態で AI を学習させます。例えば下のような点を2つの大きなグループにまとめられそうですよね。しかし分け方に正解はないですよね。AIに対して出す指示は「この点群をうまい感じにクラス分け(クラスタリング)しといて~」という感じで、どうクラス分けするか、何個に分けるのかという情報は AI 自身に考えてもらいます。これが教師なし学習です。

クラスタリングの例CycleGAN

CycleGAN

CycleGANの前に

 簡単に一言で説明すると CycleGAN は GAN の手法のうちの1つであり、 2 つの Unpaired (ペアでない) 画像を用意して、教師なし学習を行っています。GANについては以前の記事で書いたので、ぜひそれをご覧ください。

Pix2Pix

 CycleGANのと似た方法の一つに Pix2Pix があり、これは教師あり学習です。Pix2Pix はこのように絵の輪郭があり、「これに色を付けてくれ」という指示を AI に与えます。これは学習させるときに『輪郭』の画像と『色付き』の画像の Pair を用意します。つまり下のように AI はある靴の『輪郭』の画像に色を付けたら右隣の靴の『色付き』画像になるという正解が与えられています。これは常に正解ラベルが用意されたデータが用意できる場合のみに機能します。(正確には『色付き』画像の輪郭抽出は簡単なので効率的にデータセットを用意できる)。

Pix2Pix

CycleGAN

 お待たせしました、ここでやっと本題に入ります。(ここまで読んでいただいた方は疑問点もあると思いますが、多分ここですっきりします。)
 馬をシマウマにしたい!!よくわからない願望ですが、やってみましょう。それでは今までのようにデータと正解を用意して…、あれ?つまり下の図のように姿勢が同じデータを用意する必要があるぞ。Photoshop を開いて、馬の部分に縞模様を付けて書き出して….これを何百枚、何千枚とやれば….うーん、無理!!となりますね。つまり今回の場合は、上で紹介した靴のように正解がなく、Unpair な画像しか存在しません。そこで『教師なし学習』である CycleGAN の登場です。
 CycleGAN では馬の画像のまとまり (馬ドメインと呼びます) とシマウマの画像のまとまり (シマウマドメインと呼びます) 間でそれぞれの特性を学習し、馬からシマウマ、シマウマから馬を生成できるようにします。

もしPairを用意するとしたらこのような感じ

順方法 (馬→シマウマ) の生成ネットワーク

 馬からシマウマを生成するネットワークを見ていきます。

馬→シマウマ

 このように馬ドメイン X の中にある馬の画像からシマウマドメイン Y のなかにあるシマウマの画像を生成(に変換)する Generator G とその作られた偽シマウマの真偽を判断する Discriminator D_Y があります。

  •  Generator G は頑張ってシマウマの偽画像を作り Discriminator D_Y をだまそうと学習します。

  •  対してDiscriminator D_Y は Generator G の作った巧妙なシマウマ画像に騙されないように学習します。

この 2 つが競い合って学習しています。これは GAN の特性ですね。
 ちなみに ” を生成(に変換)" と書いているのは CycleGAN の Generator は無から画像を生成しているわけではなく元の画像をベースに変換しているからです。(以下同様の理由)

逆方向 (シマウマ→馬) の生成ネットワーク

 逆にシマウマから馬を生成するネットワークを見ていきます。

シマウマ→馬

 このようにシマウマドメイン Y のなかにあるシマウマの画像から馬ドメイン X の中にある馬の画像を生成(に変換)する (逆)Generator F とその作られた偽馬の真偽を判断する Discriminator D_X があります。これも同様に

  •  Generator F は頑張って馬の偽画像を作り Discriminator D_X をだまそうと学習します。 

  • 対してDiscriminator D_X は Generator F の作った巧妙な馬画像に騙されないように学習します。

この 2 つが競い合って学習しています。

全体のネットワーク

 はい、ではこの二つを連結します。これが CycleGAN のネットワークです。2 つの Generator があり、2 つの Discriminator があります。そして矢印を見ると X から出て、Y になり、X に戻っています。これが "Cycle" GAN と呼ばれている理由です。

CycleGANのネットワーク

損失関数

 では Cycle させることが分かったところで Cycle させると何がいいのでしょう。それを解明するためには損失関数を見る必要があります。少し数式が出てきますが、もしわからなければ数式の理解は後ででいいので、分だけ読んでいってください。

サイクル一貫性損失 (L_cycle)

サイクル一貫性損失 (L_cycle)

 いきなり数式で申し訳ないです…。意味を簡単に解説します。
 まずはこの式の第一項 (+ より前) の説明です。馬の画像 X をシマウマの画像を生成(に変換)する Generator G を通してシマウマに(変換)します(G(X))。それを馬の画像を生成(に変換)する (逆)Generator F  に突っ込みます。するとまた馬の画像 F(G(X)) になって出てきます。もし 2 つの Generator がものすごい高精度であれば、この画像 F(G(X)) は元の入力画像とほぼ変わらず、一致しているはずですね。というわけで Cycle を一周させたときに元の画像との変化がなるべく少なくなるように Generator を学習させます。
 逆 ( + より後)もしかりです。この入力シマウマ画像とシマウマ→馬→シマウマと Cycle させた画像がほぼ変わらないとだめですね。この時、画像間の変化を L1 正則化で表現していてこれをなるべく小さくしようとしているのです。

自己同一性損失 (L_idt)

自己同一性損失 (L_idt)

 この式の第一項 (+ より前) の説明です。シマウマ生成器にシマウマを入れたときにそれを変化させてしまっては困りますよね。本来、 Generator G は馬からシマウマを生成(に変換)する生成器でした。そこにシマウマの画像 Y を入力しています。なのでこの損失を導入しシマウマ生成器にシマウマを入れた時に変化させないようにしています。
 同様に第二項 ( + より後) を見てみると、本来、 Generator F はシマウマから馬を生成(に変換)する生成器です。そこに馬画像を突っ込んでいるのでここでも大きな変化が起きてしまっては困ります。そこでこの自己同一性損失が役に立ってきます。
 第一項を見ると、シマウマ生成器 Generator G にシマウマ画像を入力したときに、元のシマウマ画像との L1 正則化をしています。つまりこの二つの距離 (類似度) を表しています。後半も同様ですね、馬生成器 Generator F に馬画像を入力したときに、元の馬画像との L1 正則化をしています。これを最小化するのがこの自己同一性損失です。
 この損失を考慮すると「変更しなくてもいい部分は変更しない」ということを 2 つの Generator が学習し、背景などのあまり変わってほしくない部分の変更が少なくなります。

敵対的損失 (L_GAN)

敵対的損失 (L_GAN)

  これはGANの解説で説明しました。簡単に説明すると Generator と Discriminator がお互いをだましあうようにする損失です。上が馬→シマウマ、下がシマウマ→馬の方向の損失関数です。これら2つ足し合わせて敵対的損失 (L_GAN) と呼びます。詳しくはこちらをどうぞ。

全体の損失(L_CycleGAN)

全体の損失(L_CycleGAN)

 最後にそのすべての損失に重みを付けて足し合わせれば、全体の損失(L_CycleGAN) となります。(論文ではこの損失に名前がついていなかったので今回、(L_CycleGAN) としました。)

実際に使ってみた

 それでは最後にCycleGANを使って実際に馬をシマウマにしてみたいと思います使ってみたいと思います。

シマウマ→馬の実験結果

 確かにシマウマが馬になっていますね。
 でもやはり真ん中の画像とかを見るとうまくいってませんね、これは馬が小さすぎたため、うまくいかなかったと考えられます。
 そして一番右の画像では、入力は馬が3頭いますが、出力されたシマウマはつながってしまっていますね。 AI はこれが馬であることは理解してるかもしれませんが、どこからどこまでが馬の領域なのか、何頭いるのかというのは理解していません。そもそも AI は画像のピクセルごとの数値とその周りのピクセルとの関係しか見ていないのです。

ついでにシマウマから馬生成もこんな感じです。

シマウマ→馬

CycleGAN の他の利用例

 馬→シマウマ以外にもこんなことをしています。

紅葉させる

緑葉↔紅葉

ミカン ↔ リンゴ

ミカン↔リンゴ

モンスターボール ↔ マスターボール

モンスターボール ↔ マスターボール

犬 ↔ 猫

最後に

 そして最後に CycleGAN を改良して軽く、より早くしてみます。と思ったのですが、さすがに5000文字を超えそうなので今回はここまでにして、気が向いたら CycleGAN の改良について
Teachers Do More Than Teach: Compressing Image-to-Image Models
こちらの論文を解説します。
 最後まで読んでいただきありがとうございました。最後に少し宣伝です。主のteftefが運営を行っているdiscordサーバーを載せます。このサーバーではMidjourneyやStble Diffusionのプロンプトを共有したり、研究したりしています。ぜひ参加してお絵描きAIを探ってみてはいかがでしょう。(teftef)


いいなと思ったら応援しよう!