Zygote Internals¶

Julia の関数 f に対してその微分が f' で得られる仕組みとその内部についての簡単な説明をする(簡単だとは言っていない)

Load packages¶

In [1]:
using Zygote
In [2]:
]st Zygote
Status `~/work/zygote_flux_tutorial/zygote_flux_tutorial/Project.toml`
  [e88e6eb3] Zygote v0.6.62

復習¶

二次関数の微分をすることを考える.

In [3]:
f(x) = x^2
Out[3]:
f (generic function with 1 method)

もちろん, 理論上微分結果は $2x$ となる. さて, これを Julia で実行するには次のように

f'

とすれば良い.

In [4]:
x = 3
@assert f'(x) == 2x == 6

f' の 'って何?¶

実は ' は Base.adjoint 関数の役割を担う.

ドキュメントレベルでは Punctuation に説明されている:

' a trailing apostrophe is the adjoint (that is, the complex transpose) operator Aᴴ

ソースコードレベルでは

Julia本体のリポジトリ base/operators.jl で確認できる:

const var"'" = adjoint

関数に対する adjoint(f) の振る舞いは Zygote.jl/src/compiler/interface.jl で次のように実装されている:

Base.adjoint(f::Function) = x -> gradient(f, x)[1]
In [5]:
x = 3
adjoint(f)(x) == f'(x)
Out[5]:
true

gradient and pullback¶

gradient 関数は大まかに次のようなことをしていると思って良い:

function gradient(f, x)
  y, back = pullback(f, x)
  return back(1)
end

厳密には Zygote.jl/src/compiler/interface.jl を見ること

y, back = pullback(f, x) の戻り値の y は f(x) のことである. back は関数である. これを説明するのに 記号を準備する.

変数 $x$ と関数 $f=f(x)$ が与えられているとする. いま, ある関数 $L = L(y)$ where $y=f(x)$ に対して

  • $L$ の $x$ に関する微分を $\displaystyle \bar{x} = \frac{dL}{dx}$,
  • $L$ の $y$ に関する微分を $\displaystyle \bar{y} = \frac{dL}{dy}$

と書くことにする. この時, 合成関数の連鎖律から

$$ \bar{x} = \frac{dL}{dx} = \frac{dL}{dy}\frac{dy}{dx} = \bar{y}\frac{dy}{dx} =: B_x(\bar{y}) $$

となる. 上述した back は $\bar{y}$ を入力とし $\bar{y}\frac{dy}{dx}$ を出力する関数 $B_x(\bar{y})$ になる. よって back(1) は $\frac{dy}{dx}$ つまり $f'(x)$ を計算していることになる. これは $L(y) = y$ と設定した状況で $\bar{x}$ を計算していることに他ならない.

  • 例えば $f(x)=\sin(x)$ の場合は gradient や pullback の振る舞いは次のようになっていると思えば良い: ただし,ここでは説明を簡単にするために実際の実装と戻り値の型を簡易にしているため厳密には異なる.
In [6]:
square(x) = x * x

pullback_for_sin(x) = (sin(x), ȳ -> ȳ * cos(x))
pullback_for_cos(x) = (cos(x), ȳ -> ȳ * -sin(x))
pullback_for_square(x) = (x * x, ȳ -> ȳ * 2x)

# 自前の pullback 関数
mypullback(::typeof(sin), x) = pullback_for_sin(x) # f=sin の場合の実装
mypullback(::typeof(cos), x) = pullback_for_cos(x) # f=cos の場合の実装
mypullback(::typeof(square), x) = pullback_for_square(x) # f=square の場合の実装

# 自前の gradient 関数
function mygradient(f, x)
    _, back = mypullback(f, x) # f によって振る舞いが異なる.
    return back(1)
end

x = π/6
@assert mygradient(sin, x) ≈ sin'(x) ≈ cos(x)
@assert mygradient(cos, x) ≈ cos'(x) ≈ -sin(x)
@assert mygradient(square, x) ≈ square'(x) ≈ 2x

Julia が提供しているさまざまな関数に対する pullback の実装は ChainRules.jl/ChainRulesCore.jl にある規則に基づいている.

三角関数の場合は, 例えばここ で見ることができる.

ちなみにタンジェント関数 tan に対する微分規則は @scalar_rule というマクロを用いて定義している.

@scalar_rule tan(x) 1 + Ω ^ 2

これは Ω = tan(x) という定義のもとで 1 + Ω² を微分値として返却する pullback を実装するマクロである.

$\tan'(x) = \frac{1}{\cos^2(x)} = 1 + \tan^2(x)$ という数学的性質を利用していることに注意.

Zygote.jl と ChainRules.jl の関係性については

  • JuliaCon2020 の講演 JuliaCon 2020 | ChainRules.jl | Lyndon White
  • JuliaCon2021 の講演 Everything you need to know about ChainRules 1.0 | Miha Zgubič | JuliaCon2021

が詳しい

In [7]:
x = π/6
tan'(x), 1/(cos(x))^2, 1 + (tan(x))^2
Out[7]:
(1.3333333333333333, 1.333333333333333, 1.3333333333333333)

合成関数の微分¶

  • back が関数であることは基本的な関数を合成した関数に対する微分を考える時に便利だからである.

$y=f(x)$, $z=g(y)$ の関係を持つ変数 $x, y, z$ と $L=L(z)$ があるとする.

この時 $\bar{x}=\frac{dL}{dx}$ は連鎖律によって次のように書ける:

$$ \begin{aligned} \bar{x} &= \frac{dL}{dx} = \frac{dL}{dy}\frac{dy}{dx} = \bar{y} \frac{dy}{dx} = B_x(\bar{y}) \\ \bar{y} &= \frac{dL}{dy} = \frac{dL}{dz}\frac{dz}{dy} = B_y(\bar{z}) \end{aligned} $$

いま $L=z$ とすれば $\frac{dz}{dx}$ は $(B_x\circ B_y)(\bar{z}=1)$ として計算できる. これが back が関数であることの所以である.

Example¶

$z(x) = \sin(x^2)$ というケースの微分を考える.

In [8]:
square(x) = x * x
f(x) = square(x) # x * x の事.
g(y) = sin(y)
# 合成関数を定義
z(x) = (g ∘ f)(x) # ∘ は \circ とタイプする. もちろん z(x) = g(f(x)) でも良い.

pullback_for_f(x) = (x * x, ȳ -> ȳ * 2x) # (f(x), B_x(ȳ)) みたいなやつ. 中身は pullback_for_square と同じ
pullback_for_g(x) = (sin(x), ȳ -> ȳ * cos(x)) # pullback_for_sin と同じ

mypullback(::typeof(f), x) = pullback_for_f(x)
mypullback(::typeof(g), x) = pullback_for_g(x)

# 合成関数 z に対する pullback の実装
function pullback_for_z(x)
    y, back_for_f = mypullback(f, x)
    z, back_for_g = mypullback(g, y)
    function back_for_z(z̄)
        ȳ = back_for_g(z̄)
        x̄ = back_for_f(ȳ)
        return x̄
    end
    return (z, back_for_z)
end

mypullback(::typeof(z), x) = pullback_for_z(x) # f=z の場合の実装

function mygradient(f, x)
    _, back = mypullback(f, x) # f によって振る舞いが異なる.
    return back(1)
end

x = π/6
@assert mygradient(z, x) ≈ (z)'(x) ≈ 2x * cos(x^2)

Zygote のお仕事¶

Zygote ではプログラマーが記述したコードを中間表現に変換しその中身を解析することで上記で行なったような mypullback 関数を作るお仕事をしている.

Zygote.@code_ir ¶

Zygote.jl ではプログラマーが書いた関数を Static Single Assignment form, SSA という中間表現(IR)に変換する. ここではプログラマーが記述したプログラムを Julia のコンパイラーにとって優しいものに変換する程度の理解で良い.

例えば多項式函数は次のように変換される.

julia> poly(x) = 3x*x + 2x + 1
julia> Zygote.@code_ir poly(1.)
1: (%1, %2) # %1 は関数 poly, %2 は引数 x
  %3 = 3 * %2 # まずは入力を 3倍する i.e. 3x を計算
  %4 = %3 * %2 # この時点で 3 x^2 を実現
  %5 = 2 * %2 # 入力を2倍にしたものを計算 i.e. 2x を計算
  %6 = %4 + %5 + 1 # 諸々を足し合わせる
  return %6

詳しいことは IRTools.jl の説明 を見ると良い.

In [9]:
poly(x) = 3x*x + 2x + 1
Zygote.@code_ir poly(1.)
Out[9]:
1: (%1, %2)
  %3 = 3 * %2
  %4 = %3 * %2
  %5 = 2 * %2
  %6 = %4 + %5 + 1
  return %6

先ほど定義した z についても変換ができる

In [10]:
Zygote.@code_ir z(1.)
Out[10]:
1: (%1, %2)
  %3 = Main.g ∘ Main.f
  %4 = (%3)(%2)
  return %4

Zygote.@code_adjoint¶

Zygote.@code_ir の結果を得て Zygote が生成する pullback は Zygote.@code_adjoint で確認することができる.

julia> Zygote.@code_adjoint poly(1)
Zygote.Adjoint(1: (%3, %4 :: Zygote.Context, %1, %2)
  %5 = Zygote._pullback(%4, Main.:*, 3, %2) 
  %6 = Base.getindex(%5, 1)
  %7 = Base.getindex(%5, 2)
  %8 = Zygote._pullback(%4, Main.:*, %6, %2)
  %9 = Base.getindex(%8, 1)
  %10 = Base.getindex(%8, 2)
  %11 = Zygote._pullback(%4, Main.:*, 2, %2)
  %12 = Base.getindex(%11, 1)
  %13 = Base.getindex(%11, 2)
  %14 = Zygote._pullback(%4, Main.:+, %9, %12, 1)
  %15 = Base.getindex(%14, 1)
  %16 = Base.getindex(%14, 2)
  return %15, 1: (%1)
  %2 = (@16)(%1)
  %3 = Zygote.gradindex(%2, 2)
  %4 = Zygote.gradindex(%2, 3)
  %5 = (@13)(%4)
  %6 = Zygote.gradindex(%5, 3)
  %7 = (@10)(%3)
  %8 = Zygote.gradindex(%7, 2)
  %9 = Zygote.gradindex(%7, 3)
  %10 = (@7)(%8)
  %11 = Zygote.gradindex(%10, 3)
  %12 = Zygote.accum(%6, %9, %11)
  %13 = Zygote.tuple(nothing, %12)
  return %13)
  • よーく(ベニマル)目を凝らすと上記で定義した pullback_for_z にて定義した表記と似ていることがわかる. 実際,
  %5 = Zygote._pullback(%4, Main.:*, 3, %2) 
  %6 = Base.getindex(%5, 1)
  %7 = Base.getindex(%5, 2)

は平たくいうと

mul(x, y) = x * y # multiply
mypullback_for_mul(x, y) = (x * y, Δ -> (Δ *y, Δ * x))
mypullback(::typeof(mul), x, y) = mypullback_for_mul(x, y)

という設定のもとで

p5 = mypullback(mul, 3, x)
p6 = p5[1]
p5 = p5[2]

をしていると思えば良い. 配列 X に対する添字 i へのアクセスは Base.getindex(X, i) と書ける.

上の3 行を 1 行で書くと

out, back = mypullback(mul, 3, x)

となる. 一方で,

return %15, 1: (%1)
  %2 = (@16)(%1)
  %3 = Zygote.gradindex(%2, 2)
  %4 = Zygote.gradindex(%2, 3)
  %5 = (@13)(%4)
  %6 = Zygote.gradindex(%5, 3)
  %7 = (@10)(%3)
  %8 = Zygote.gradindex(%7, 2)
  %9 = Zygote.gradindex(%7, 3)
  %10 = (@7)(%8)
  %11 = Zygote.gradindex(%10, 3)
  %12 = Zygote.accum(%6, %9, %11)
  %13 = Zygote.tuple(nothing, %12)

の部分は back_for_z の実装相当のことしていると思えばよい. ここでは, Zygote.gradindex は Bsae.getindex と同じものという理解で良い.

これで pullback_for_z にて定義した表記と似ているということが分かったと思う.

さて, これらから

Zygote.@code_adjoint poly(1)

の結果から pullback_for_poly 関数を定義することが出そうである.

In [11]:
poly(x) = 3x*x + 2x + 1

add(x, y, z) = x + y + z
mul(x, y) = x * y 

pullback_for_add(x, y, z) = (add(x, y, z), Δ -> (Δ, Δ, Δ))
pullback_for_mul(x, y) = (mul(x, y), Δ -> (Δ * y, Δ * x)) # 第二成分は (∂_x mul, ∂_y mul) と思えば良い

mypullback(::typeof(add), x, y, z) = pullback_for_add(x, y, z)
mypullback(::typeof(mul), x, y) = pullback_for_mul(x, y)

function pullback_for_poly(x)
    # %2 は x である

    # %5 = Zygote._pullback(%4, Main.:*, 3, %2)
    # %6 = Base.getindex(%5, 1)
    # %7 = Base.getindex(%5, 2)
    out6, back7 = mypullback(mul, 3, x)
    # %8 = Zygote._pullback(%4, Main.:*, %6, %2)
    # %9 = Base.getindex(%8, 1)
    # %10 = Base.getindex(%8, 2)
    out9, back10 = mypullback(mul, out6, x)
    # %11 = Zygote._pullback(%4, Main.:*, 2, %2)
    # %12 = Base.getindex(%11, 1)
    # %13 = Base.getindex(%11, 2)
    out12, back13 = mypullback(mul, 2, x)
    # %14 = Zygote._pullback(%4, Main.:+, %9, %12, 1)
    # %15 = Base.getindex(%14, 1)
    # %16 = Base.getindex(%14, 2)
    out15, back16 = mypullback(add, out9, out12, 1)
    function back_for_poly(ā)
        # %1 は ā のことである
        # gradindex(∘, 1) は形式的なオブジェクトが入っているので意味のある出力は gradindex(∘,2), gradindex(∘, 3) などである
        # (@13) などは back13 などと対応する
        
        # %2 = (@16)(%1)
        # %3 = Zygote.gradindex(%2, 2)
        # %4 = Zygote.gradindex(%2, 3)
        x̄3, ȳ4 = back16(ā)
        # %5 = (@13)(%4)
        # %6 = Zygote.gradindex(%5, 3)
        _, ȳ6 = back13(ȳ4)
        # %7 = back10(%3)
        # %8 = Zygote.gradindex(%7, 2)
        # %9 = Zygote.gradindex(%7, 3)
        x̄8, ȳ9 = back10(x̄3)
        # %10 = (@7)(%8)
        # %11 = Zygote.gradindex(%10, 3)
        _, ȳ11 = back7(x̄8)
        # %12 = Zygote.accum(%6, %9, %11)
        # %13 = Zygote.tuple(nothing, %12)
        ō12 = sum([ȳ6, ȳ9, ȳ11]) # accum は平たくいうと sum と同じ
        return ō12 # 値はスカラーとして返したいのでここでは tuple に変換はしないでおく
    end
    return out15, back_for_poly
end

mypullback(::typeof(poly), x) = pullback_for_poly(x)

x = rand()
@assert mygradient(poly, x) ≈ (poly)'(x) ≈ 6x + 2
  • できたんジェントシータピオカ!(できた!)

まとめ¶

  • Zygote.jl の内部でやっていることの説明についてのべた

このノートブックで行っていることがわかれば https://fluxml.ai/Zygote.jl/latest/ のドキュメントも理解できるようになるはずだ. 意欲のある方は https://github.com/FluxML/Zygote.jl のソースを追いかけると幸せになれるかもしれない.