作って遊ぶ機械学習。

~基礎的な確率モデルから最新の機械学習技術まで~

変分近似(Variational Approximation)の基本(3)

「作って遊ぶ」を題目として掲げておきながらまだ作っても遊んでもいなかったので、今回はそろそろ何か動くものを載せたいと思います。

さて、前回得られた変分近似のアルゴリズムを導出するための手引きを使って、今回は世界で一番簡単だと思われる2次元ガウス分布に対して近似推定をやってみたいと思います。*1

 

[必要な知識]

下記をさらっとだけ確認しておくといいです。

 

今回は2次元のガウス分布の近似推定を例として行いますが、実を言うと、多次元ガウス分布積分も解析的にできますしサンプリングも簡単にできるような単純な分布なので、近似分布をわざわざ求める意味は皆無です。しかし、この例は計算がとてもシンプルで変分近似の導出手順を説明しやすいのと、変分近似が「近似してしまっているもの」が何なのか明確化することができるので、基本を説明するには十分な例だと思います。

では、次のような2次元のガウス分布を変分近似を使って近似推定してみましょう。

\[ p(x_1,x_2|\mu_1,\mu_2, \Lambda) = \mathcal{N} \Bigl(\left[\begin{array}{r} x_1 \\ x_2 \end{array}\right] \bigg| \left[ \begin{array}{r} \mu_1 \\ \mu_2 \end{array} \right], \left[ \begin{array}{cc} \Lambda_{1,1} & \Lambda_{1,2} \\ \Lambda_{2,1} & \Lambda_{2,2} \end{array} \right]^{-1} \Bigr) \\ \propto exp\Bigl\{ \Bigl(\left[\begin{array}{r} x_1 \\ x_2 \end{array}\right] - \left[ \begin{array}{r} \mu_1 \\ \mu_2 \end{array} \right] \Bigr)' \left[ \begin{array}{cc} \Lambda_{1,1} & \Lambda_{1,2} \\ \Lambda_{2,1} & \Lambda_{2,2} \end{array} \right] \Bigl(\left[\begin{array}{r} x_1 \\ x_2 \end{array}\right] - \left[ \begin{array}{r} \mu_1 \\ \mu_2 \end{array} \right]\Bigr) \Bigr\} \]

$x_1$,$x_2$はガウス分布からサンプルされる確率変数です。$\mu_1$,$\mu_2$は平均値、$\Lambda$は$x$の精度行列*2で、2つとも今回は適当な値で固定されたパラメータです。

さて、これをある近似分布$q(x_1,x_2)$で推定しようと思います。2変数の確率分布を分解して推定しようとしているので、次のように2つの分布に分解するしか今回は選択がないです。

\[q(x_1, x_2) = q(x_1)q(x_2) \]

さて、前回紹介した公式

\[ \ln q(z_1) = \langle \ln p(z_1, z_2| x) \rangle_{q(z_2)} + c \]

を適用してみましょう。*3

\[ \ln q(x_1) = \langle \ln p(x_1, x_2|\mu_1,\mu_2, \Lambda) \rangle_{q(x_2)} + c \\ = -\frac{1}{2} \langle x_1^{2}\Lambda_{1,1} - 2 x_1 (\Lambda_{1,1} \mu_1 - \Lambda_{1,2}(x_2 - \mu_2)) \rangle + c \\ = -\frac{1}{2} \{x_1^{2}\Lambda_{1,1} -2 x_1 (\Lambda_{1,1} \mu_1 - \Lambda_{1,2}(\langle x_2 \rangle - \mu_2)) \} + c \]

求めたいのは$x_1$に関する確率分布です。なので$x_1$にだけ注目し、無関係な項をすべて定数$c$に吸収させてしまうのが計算上のポイントです。さらにブラケット$\langle \cdot \rangle$を使って表現した期待値計算ですが、ここでは$x_2$のみに適用してあげればOKで、$x_2$に無関係な項たちはブラケットをするりと抜け出すことができます。

さて、この式をよく見てみると、$x_1$の「上に凸の2次関数」になっていることがわかります。対数計算の結果が上に凸の2次関数になっているということは、この確率分布は1次元のガウス分布であることを表しています。したがってこの式から、平均と分散を求めてあげれば近似分布が求まります。*4

\[ q(x_1) = \mathcal{N}(x_1| m_1, \Lambda_{1,1}^{-1}) \]

ただし、

\[ m_1 = \langle x_1 \rangle = \mu_1 - \Lambda_{1,1}^{-1}\Lambda_{1,2}(\langle x_2 \rangle - \mu_2) \]

です。

$q(x_2)$の計算も同様に計算でき、下記のようになります。*5

\[ q(x_2) = \mathcal{N}(x_2| m_2, \Lambda_{2,2}^{-1}) \]

\[ m_2 = \langle x_2 \rangle = \mu_2 - \Lambda_{2,2}^{-1}\Lambda_{2,1}(\langle x_1 \rangle - \mu_1) \]

 

以上から、得られる疑似コードは次のようになります。

  1. $m_2$をランダムに初期化する。
  2. $m_1$を更新する。
  3. $m_2$を更新する。
  4. 以上、2と3を十分な回数まで繰り返す。

結果的に近似分布の精度は$\Lambda_{1,1}$と$\Lambda_{2,2}$のまま更新されず、平均値だけ更新されていくようなアルゴリズムになりましたね。

さて、これを実装して動かした結果が次のものです。

 

f:id:sammy-suyama:20160131135817g:plain

f:id:sammy-suyama:20160131135416p:plain

上図では、青い楕円が推定したい真のガウス分布で、赤い楕円が推定中の近似分布です(σ=1にあたるところで楕円を描いています)。繰り返し回数は50回にしています。ランダムな平均値から出発して、だいたい15回目くらいの更新で収束しているように見えますね。下図では対応する真の分布と近似分布との間のKL divergenceを繰り返しごとにプロットしてみました。

 

要点を少し挙げてみます。

1、近似分布が共分散を表現できていない

上図の赤い楕円(近似分布)は常に軸に平行になっており、絶対に斜め向きにはなってくれません。これは近似分布の独立性(分解)を仮定しているため、x1とx2の相関が表現できなくなっているためです。このように変分近似では、最初に独立性を仮定してしまった変数間の相関は取ることができません。

2、KL divergenceが単調に減少している

下図を見ればわかるように、真の事後分布と近似分布との間のKL divergenceが単調に減少していることが分かります。これは、毎回の更新でKL divergenceを最小化する方向に近似分布を修正しているからです。もしこれが増加する場合はバグなので、更新式やソースコードを見直してみる必要があります。

 

今回の実験に関するソースコードは時間があるときにGitHubに上げようと思います。

すっかり忘れていましたが、Juliaで実装したものをGitHubに公開しました。

MLBlog/demo_simpleVI.jl at master · sammy-suyama/MLBlog · GitHub

基本的にはJulia上で

julia> include("demo_simpleVI.jl")

と叩けば今回のような図が色々出てくるかと思います。

 

次回以降は、もうちょっと複雑な、でも現実的なモデルに対して変分近似を適用してみたいと思います。

 

[続き・関連]

MCMCと変分近似 - 作って遊ぶ機械学習。

*1:まったく同じ例がPRMLにもあります。世界で一番簡単な例なので許してください。

*2:精度行列は共分散行列の逆行列なのですが、今回の例では共分散行列を使うよりもこの精度行列の方が導出がスッキリします。

*3:前回導いた公式では右辺は観測データによって条件づけされていますが、今回はそれに該当するのはガウス分布のパラメータ($\mu$、$\Lambda$)です。ベイズ学習における推論では、既知の固定パラメータと観測データには数学的な区別はありません。

*4:高校数学で教わった平方完成を使います。2次式を$-\frac{1}{2}(x - m)' \Lambda (x - m)$と置いて展開してあげれば、逆からどのように平均と精度を求めたらいいかがわかるかと思います。

*5:実は2つの更新式を連立方程式として解くと2つの平均値$\langle x_1 \rangle$と$\langle x_2 \rangle$が解析に求まってしまいます。今回は実際の多くの応用のように、あえて繰り返しアルゴリズムを使って解を求めます。