東大生AI初心者の学習日誌 Day9「ベイズ線形回帰(2)」

こんにちは、じゅんペー(@jp_aiboom)です!

僕は現在東京大学の理系の二年生です。この連載では、AI初心者の僕が、「パターン認識と機械学習(通称PRML)」を読み進めながら、機械学習の理論面を一から勉強していく様子をお届けしたいと思います。

第9回目の今回は、前回(第8回)の続きでベイズ線形回帰を扱っていきます。まだ前回の記事を読まれていない方は、是非そちらを先にお読みいただきたいです。

▶ 東大生AI初心者の学習日誌 Day8「ベイズ線形回帰(1)」

ベイズ線形回帰の目的(前回のおさらい)

ベイズ線形回帰を適用するシチュエーションについておさらいします。

前回までの記事で、上の図の各点(青点)を近似する曲線を「多項式曲線フィッティング」にて求めた(赤線)ところ、そこそこの精度は出ることが分かりました。

しかし、データが少ないので、近似の精度が高い場所とそうでない場所があると考えます。それぞれの精度を確率分布として表そうと思った際に、ベイズ線形回帰が必要になってくるのでした。

具体的には、以下のようなイメージです。


各xに対して、その上に確率分布を描くことで(この図では正規分布)、どこがどれくらい信頼できて、どれくらいの分散があるのかなどを視覚的に表すのが目標です。

ということで、今回は、この確率分布を表し方についての解説から始めていきます。

ベイズ線形回帰

では、ここから実際どうやって線形回帰に確率分布を導入するかみていきましょう。

まず、データ点一つ一つに対応する確率分布を以下の式で仮定します。
$$p(t|\boldsymbol{w},x)=N(t|\boldsymbol{w}^T\boldsymbol{\phi},\beta^{-1}) = \sqrt{\frac{\beta}{2\pi}}\exp(-\frac{1}{2}\beta(t-\boldsymbol{w}^T\boldsymbol{\phi}(x)))$$
この\(N(t|\boldsymbol{w}^T\boldsymbol{\phi}(x),\beta^{-1})\) は、平均が\(\boldsymbol{w}^T\boldsymbol{\phi}\) 、分散が\(\beta^{-1}\) である正規分布です。今回考えている問題では、予測した曲線から離れるほど確率が低くなっていくような分布になることが予想がつくので、妥当ですね。先に示した図でもこの通りになっています。

以後の話は、Day6Day7のベイズ推定の記事で解説した内容がかなり含まれており、慣れていない方には難しい可能性があるので、まだ読まれていない方は先にお読みいただきたいです。

上の式の中の\(\beta^{-1}\) は、正規分布において分散ですが、\(\beta\) が大きいと(\(\beta^{-1}\) は小さくなるので)平均付近に集まる確率が高くなり、逆に小さいと幅広くなります。このような意味から、精度とも呼ばれます。

さて、上の式は、データ点一つ一つに対する確率分布でした。ただ、実際にはもっとたくさんデータが与えられるので、例えばデータ、X=(x1,……,xN),T=(t1,……,tN)が得られたときの確率分布も欲しいです。それを表すのが以下の式です。
$$p(T|\boldsymbol{w},X) = p(t_1,t_2,…,t_n|\boldsymbol{w},x_1,x_2,…,x_n) = \prod_{i=1}^n p(t_i|\boldsymbol{w},x_i)$$

ここでは、データが全て独立に同じ分布から得られた時には、単純に掛け算をすればいいという原則を使っています。

また、今回はベイズ推定の考え方を用いるので、事前分布を仮定します。今回の場合で言うと、まだ一つもデータをとってない時の\(\boldsymbol{w}\)の分布です。

Day6の記事と同様に、今回も共役事前分布を選びたいので、 \(\boldsymbol{w}\)の事前分布も正規分布を仮定します。
具体的には、
$$p(\boldsymbol{w}|\alpha) = N(\boldsymbol{w}|\boldsymbol{0},\alpha^{-1}\boldsymbol{I})$$
とおきます。

これらを使うとことで、 \(\boldsymbol{w}\)の事後分布が計算でき、(計算方法はDay6のベイズの定理参照)
$$p(\boldsymbol{w}|X,T) = N(\boldsymbol{w}|\boldsymbol{m}_n,\boldsymbol{S}_n)$$
このように表すことができます。
ここで、
$$\boldsymbol{m_n} = \beta \boldsymbol{S_n} \boldsymbol{\phi}^T \boldsymbol{t},\boldsymbol{S}_n = (\alpha \boldsymbol{I}+\beta\boldsymbol{\phi}^T\boldsymbol{\phi}) $$
です。
かなり複雑になってきましたが、結果が面白いので、難しい式は読み飛ばしてもらっても問題ありません。

これでめでたく\(\boldsymbol{w}\)の事後確率が求まりました。つまり、例えば10個分のデータが得られた後、\(\boldsymbol{w}\)がどんな感じになりやすいかの分布が得られたので、いよいよこれを使って、目標であったtの予測分布を得ることができます。

前回前々回等に出てきた周辺化をすると、以下のように計算ができ、
$$\begin{eqnarray*} P(t|X,T,x) &=& \int p(t|\boldsymbol{w},x)p(\boldsymbol{w}|X,T)dw \\ &=& N(t|\boldsymbol{m}_n^T\boldsymbol{\phi}(x),\sigma_n^2(x)) \\ \sigma_n^2(x)=\beta^{-1}+\boldsymbol{\phi}(x)^T\boldsymbol{S}_n\boldsymbol{\phi}(x) \end{eqnarray*}$$

事前分布に共役なものを選んだので、tの予測分布もちゃんと正規分布になりました。

これを図示するとどんなグラフが描かれるかというと、\(\boldsymbol{m}_n^T\boldsymbol{\phi}(x)\) で表される曲線によって元のsin関数が近似され、この近似関数の各xの値を中心に、分散\(\sigma_n^2(x)\) で分布する正規分布になります。

ベイズ線形回帰の実装

それでは、実際に実装してみます。

今回は、点の近くと、点が近くない部分で精度が異なってくることが分かりやすいように、以下の図のように8個の点データで試します。

これは意図的に、x=0.4からx=0.6の間に点がなく、x=0.7の次はx=0.85の位置に点がプロットされていて、他の点は0.1ずつプロットされています。

このデータに対して、
$$P(t|X,T,x) = N(t|\boldsymbol{m}_n^T\boldsymbol{\phi}(x),\sigma_n^2(x))$$
を図示してみると、以下のようになりました。

この図は、緑の点がデータ点、赤の曲線がsinの近似曲線です。いろんな色がついていて少し複雑ですが、この色の分布が表すのは、内側の白っぽい部分がもっとも確率が高い部分で、外側の色の部分にいくにつれて確率が下がっていき、一番外側の濃い紫の部分が最も確率が低いことを表しています。

上図の白い線のt軸上で詳しくみてみると、この色の分布は等高線のようなものなので、大体で正規分布のようなものを書いてみると、このような白い山ができます。

見方としては、この白いt軸上では、今は赤い線上の値で予測していますが、白のt軸上の白い二点間の幅のどこかになる確率もかなり高くて、さらにその外側の青い二点の間になる確率もそこそこあるということを表しています。

これを踏まえると、x=0.4からx=0.6の間あたりには、そもそも一番確率の高い白っぽい部分はなく、一応赤い線で予測はしているものの、実際にはそこまで精度が高くないことを表しています。

この結果は、冒頭で述べた「近似の精度が高い場所とそうでない場所がある」という予想合致します。直感的には正しそうなことが実際に可視化されて目に見えると面白いですね。

今回も最後まで読んでいただきありがとうございました。

連載「東大生AI初心者の学習日誌 」
Day1「機械学習の全体像」
Day2「多項式曲線フィッティング」
Day3「多項式曲線フィッティングと過学習」
Day4「過学習と正則化」
Day5 A判定で東大に落ちる確率は?計算してみた!(1)~最尤推定編~
Day6 A判定で東大に落ちる確率は?計算してみた!(2)~MAP推定編~
Day7 A判定で東大に落ちる確率は?計算してみた!(3)~ベイズ推定編~
Day8「ベイズ線形回帰(1)」
Day9「ベイズ線形回帰(2)」

業界から探す

PAGE TOP