見出し画像

自動微分 と dual number

先日大学で自動微分について講義を受けて,個人的にかなり面白いなと思ったので,学んだことや調べたことや考えたことをまとめておこうと思います.

自分は少し話を聞いただけの素人なので,書いていることが間違っていたりするかもしれません.お気づきのことがあれば指摘していただけるとありがたいです.

ボトムアップ型の自動微分について書きます.

自動微分とは

コンピュータで微分を計算するにはどうすればいいでしょうか.微分というのは定義に戻れば極限の計算であり,コンピュータはこのような極限の計算はあまり得意としていないようです.そのため以前は(今も?)数値微分という,極限の計算を近似的に行う手法で微分を計算するのが主流だったようです.微分は

$$
f'(a) = \lim_{h \to 0} \frac{f(a + h) - f(a)}{h}
$$

なので例えば,$${h = 0.00001}$$などとして

$$
\frac{f(a + 0.00001) - f(a)}{0.00001}
$$

という値を計算すれば実際の微分値$${f'(a)}$$に近い値が得られそうだというのは納得できるでしょう.このように極限を計算するのではなく,0に近い値を$${h}$$に代入して近似的に微分を計算するのが数値微分です.

しかし,この数値微分はコンピュータで計算してるが故に誤差が発生しやすいという弱点を持っています.(値の近い2数の引き算$${f(a+h) - f(a)}$$による桁落ちなども発生します.)
さらに言えば,あくまでも近似的に計算しているので多項式のような簡単な関数であっても正確な微分値を計算することはできません.

そこで微分計算の代数的な側面に注目して計算しようとしているのが自動微分です.例えば多項式であれば$${\dfrac{d}{dx}x^n = nx^{n-1}}$$という微分公式があるので,この形式的な計算をコンピュータにさせれば正確に(もしくは精度の高い)微分の計算ができるのではないかということです.

dual number とは

こちらは完全に数学の話になるのですが dual number (二重数) というものがあります.これは複素数の類似のような感じで捉えることができます.複素数では$${i}$$という$${2}$$乗すると$${-1}$$になる記号を考えていました.一方で dual number では$${\varepsilon}$$という$${2}$$乗すると$${0}$$になる記号を導入します.dual number の計算のルールは複素数で$${i}$$を$${\varepsilon}$$に変えただけだと思えば大丈夫です.つまり,dual number というのは実数$${a,b}$$を用いて$${a+b\varepsilon}$$と表される数のことで,

$$
(a + b\varepsilon) \pm (c + d\varepsilon) = (a \pm c) + (b \pm d)\epsilon
$$

$$
\begin{align*}
(a + b\varepsilon) \times (c + d\varepsilon) &= ac + ad\varepsilon + bc \varepsilon + bd \varepsilon^2\\
&= ac + (ad + bc) \varepsilon\\
\end{align*}
$$

$$
\begin{align*}
\frac{a+b\varepsilon}{c + d\varepsilon} &= \frac{(a+b\varepsilon)(c-d\varepsilon)}{(c+d\varepsilon)(c-d\varepsilon)}\\
&= \frac{ac - ad\varepsilon + bc\varepsilon + bd\varepsilon^2}{c^2 - d^2\varepsilon^2}\\
&= \frac{a}{c} + \frac{bc - ad}{c^2}\varepsilon
\end{align*}
$$

というように計算します.$${\varepsilon^2 = 0}$$になるところだけ注意してください.

補足

 2乗して0になる数みたいなものを勝手に考えていいのかと疑問に思う方もいらっしゃるかもしれません.実は dual number は代数学の環論の枠組みの中でしっかりと正当化されていますので,ある程度環論がわかる方向けに少し補足を入れておきます.

$${\R}$$を係数とする多項式環$${\R[X]}$$を$${X^2}$$で生成されるイデアル$${(X^2)}$$で割った剰余環$${\R[X]/(X^2)}$$が今考えていた dual number 全体の集合になります.複素数は$${\R[X]/(X^2+1)}$$なので,まさに複素数の類似という感じがします.

dual number と 微分の関係

ではこの dual number は微分とどのような関係があるのでしょうか.それは先程の計算式で$${a = f(s), b = f'(s), c = g(s), d = g'(s)}$$と置き直してもう一度計算式を眺めてみると見えてきます.

$$
\{f(s)+f'(s)\varepsilon\} \pm \{g(s) + g'(s)\varepsilon\} = \{f(s) \pm g(s)\} + \{f'(s) \pm g'(s)\}\varepsilon
$$

$$
\{f(s)+f'(s)\varepsilon\} \times \{g(s) + g'(s)\varepsilon\} = f(s)g(s) + \{f(s)g'(s) + f'(s)g(s)\} \varepsilon
$$

$$
\frac{f(s)+g(s)\varepsilon}{g(s) + g(s)\varepsilon} = \frac{f(s)}{g(s)} + \frac{f'(s)g(s) - f(s)g'(s)}{\{g(s)\}^2}\varepsilon
$$

ここで計算後の$${\varepsilon}$$の係数に注目してみるとそれらは全て,それぞれの演算に対応した微分公式になっていることがわかると思います.なのでこの計算規則をうまく使えば,とりあえず有理関数については極限の計算をすることなく,我々が普段微分をするのと同じような方法でコンピュータに微分をさせることができそうです!

では次に多項式以外の関数$${\sin x, \cos x, e^x, \log x}$$に dual number を入れた場合,どのように計算すれば良いのでしょうか.これは複素関数の時と同じように考えれば良いです.つまり,テイラー展開を用いて値を定義することにします.例えば$${e^x}$$に対しては

$$
\begin{align*}
e^{a+b\varepsilon} &= 1 + \sum_{n=1}^{\infty} \frac{1}{n!}(a+b\varepsilon)^n\\
&= 1 + \sum_{n=1}^{\infty} \frac{1}{n!}\sum_{k=1}^n {}_nC_k a^{n-k}b^k\varepsilon^k\\
&= 1 + \sum_{n=1}^{\infty} \frac{1}{n!}(a^n + na^{n-1}b\varepsilon)\\
&= 1 + \sum_{n=1}^{\infty} \frac{a^n}{n!} +b\varepsilon \sum_{n=1}^{\infty} \frac{a^{n-1}}{(n-1)!}\\
&= e^a + be^a\varepsilon
\end{align*}
$$

となります.これと同じように考えれば,テイラー展開可能な関数$${f(x)}$$に対しては

$$
f(a+b\varepsilon) = f(a) + bf'(a)\varepsilon
$$

と計算できることがわかります.したがって関数$${f(x)}$$の$${x=a}$$における微分を計算するには,その関数に$${a+\varepsilon}$$を代入してみて,$${\varepsilon}$$の係数を確認すれば良いということになります.

これを使えば三角関数や指数・対数関数などが含まれている関数に対しても dual number を用いて微分が計算できます.

dual number を用いて微分を計算する

試しに具体的な関数を使って自動微分の過程を追ってみたいと思います.

$${f(x) = x^2 + 3x + e^{2x}}$$の$${x=1}$$での微分を求めてみましょう.

$$
\begin{align*}
f(1+\varepsilon) &=  (1+\varepsilon)^2 + 3(1+\varepsilon) + e^{2+2\varepsilon}\\
&= 1 + 2\varepsilon + \varepsilon^2 + 3 + 3\varepsilon + e^2 + 2e^2\varepsilon\\
&= 4 + e^2 + (5+2e^2)\varepsilon
\end{align*}
$$

となるので$${f'(1) = 5+2e^2}$$になるようです.念のために合っているか確認してみます.

$$
f'(x) = 2x + 3 + 2e^{2x}\\
\therefore f'(1) = 5+2e^2
$$

dual number を用いて正しく計算できていることがわかりました.

サンプルプログラム

最後に,実際にPythonで自動微分を実装してみたので紹介します.DualNumberクラスの定義が長いですが,中身はかなり単純です.多項式,三角関数,指数関数,対数関数と$${x^x}$$のような形の関数に対しての微分が計算できるようにしました.D関数が微分する関数で,値$${f'(s)}$$を計算したければD(f, s)とすることで実行できます.

# autodiff.py
import math


class DualNumber:
    def __init__(self, a=0, b=0):
        self.a = a
        self.b = b

    def __add__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a + x.a, self.b + x.b)
        else:
            return DualNumber(self.a + x, self.b)

    def __radd__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a + x.a, self.b + x.b)
        else:
            return DualNumber(self.a + x, self.b)

    def __sub__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a - x.a, self.b - x.b)
        else:
            return DualNumber(self.a - x, self.b)

    def __rsub__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(x.a - self.a, x.b - self.b)
        else:
            return DualNumber(x - self.a, self.b)

    def __mul__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a * x.a, self.a * x.b + self.b * x.a)
        else:
            return DualNumber(self.a * x, self.b * x)

    def __rmul__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a * x.a, self.a * x.b + self.b * x.a)
        else:
            return DualNumber(self.a * x, self.b * x)

    def __truediv__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(self.a / x.a, (self.b * x.a - self.a * x.b) / (x.a)**2)
        else:
            return DualNumber(self.a / x, self.b / x)

    def __rtruediv__(self, x):
        if isinstance(x, DualNumber):
            return DualNumber(x.a / self.a, (x.b * self.a - x.a * self.b) / (self.a)**2)
        else:
            return DualNumber(x / self.a, - x * self.b / (self.a)**2)

    def __pow__(self, n):
        if isinstance(n, DualNumber):
            return DualNumber(self.a ** n.a, (n.b * math.log(self.a) + n.a * self.b / self.a) * (self.a)**(n.a))
        else:
            return DualNumber(self.a ** n, n * self.a ** (n-1) * self.b)

    def __rpow__(self, base):
        return DualNumber(base ** self.a, base ** self.a * math.log(base) * self.b)

    def __neg__(self):
        return DualNumber(-self.a, -self.b)

    def __str__(self):
        if self.b < 0:
            return f"{self.a}{self.b}ε"
        else:
            return f"{self.a}+{self.b}ε"


def exp(x):
    if isinstance(x, DualNumber):
        return DualNumber(math.exp(x.a), math.exp(x.a) * x.b)
    else:
        return DualNumber(math.exp(x))


def log(x, base=math.e):
    if isinstance(x, DualNumber):
        if base == math.e:
            return DualNumber(math.log(x.a), 1 / x.a * x.b)
        else:
            return DualNumber(math.log(x.a, base), 1 / (x.a * math.log(base)) * x.b)
    else:
        if base == math.e:
            return DualNumber(math.log(x))
        else:
            return DualNumber(math.log(x, base))


def sin(x):
    if isinstance(x, DualNumber):
        return DualNumber(math.sin(x.a), math.cos(x.a) * x.b)
    else:
        return DualNumber(math.sin(x))


def cos(x):
    if isinstance(x, DualNumber):
        return DualNumber(math.cos(x.a), -math.sin(x.a) * x.b)
    else:
        return DualNumber(math.cos(x))


def tan(x):
    if isinstance(x, DualNumber):
        return DualNumber(math.tan(x.a), 1 / (math.cos(x.a))**2 * x.b)
    else:
        return DualNumber(math.tan(x))


def D(f, s):
    return f(DualNumber(s, 1)).b

そして,これを用いて実際に計算してみたものが次です.

from autodiff import exp, log, sin, cos, tan, D

def f_1(x):
    return x**2 + 3 * x + exp(2*x)

def f_2(x):
    return 3**x - 3 * cos(2*x)

def f_3(x):
    return x**2 / tan(x)

def f_4(x):
    return log(x + (1+x**2)**(1/2))

def f_5(x):
    return sin(x)**(cos(x)**(1/2))

print("( x^2 + 3x + exp(2x) )'(1)     = ", D(f_1, 1))
print("( 3^x - 3cos(2x) )'(1)         = ", D(f_2, 1))
print("( x^2 / tan(x) )'(1)           = ", D(f_3, 1))
print("( log(x + sqrt(1+x^2)) )'(1)   = ", D(f_4, 1))
print("( sin(x)**sqrt((cos(x))) )'(1) = ", D(f_5, 1))

"""
出力
( x^2 + 3x + exp(2x) )'(1)     =  19.7781121978613
( 3^x - 3cos(2x) )'(1)         =  8.751621426958419
( x^2 / tan(x) )'(1)           =  -0.1280976955687304
( log(x + sqrt(1+x^2)) )'(1)   =  0.7071067811865476
( sin(x)**sqrt((cos(x))) )'(1) =  0.5027587059146197
"""

出てきた値はMathematicaで計算した値と一致することを確認しました.特に最後のf_5関数はかなり複雑ですが,全く問題なく計算することができました.

いいなと思ったら応援しよう!