When studying Bayesian models, what we usually care about is computing the posterior probability. Unfortunately, in practical models it is rarely possible to obtain a closed-form posterior from Bayes' theorem alone. That does not diminish our love for Bayesian models: if an exact solution is out of reach, an approximate one is perfectly acceptable in practice :-). Approximation methods fall roughly into two families: stochastic methods (the representative being MCMC; the Gibbs-sampling-based training of LDA from the previous post is one example) and deterministic methods. Variational inference, the topic of this post, belongs to the latter; deterministic approximations are generally faster than stochastic ones, and their convergence is easier to assess. Variational Bayesian inference is a general solution framework, similar in spirit to the EM algorithm, widely used in probabilistic modeling, and an indispensable tool to carry around :-). This post introduces the method from a theoretical angle, aiming to convey its overall structure; for concrete applications I will give links so readers can dig in themselves :-).
Calculus of Variations
For an ordinary function $f(x)$, we can think of $f$ as an operator on real numbers: it maps a real number $x$ to a real number $f(x)$. By analogy, suppose there is an operator $F$ acting on functions, which maps a function $f(x)$ to a real number $F(f(x))$. A common example in machine learning is the entropy $H(p(x))$, which maps a probability density function $p(x)$ to a single value. In programming terms, the objects studied in the calculus of variations are higher-order functions: they take a function as an argument and return a value.
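To make this concrete, here is a small sketch (a discrete distribution stands in for the density; the function names are mine, not from any library): the entropy functional takes a whole function, a pmf, as its argument and returns one number.

```python
import numpy as np

def entropy(p):
    """A functional: maps a whole probability mass function p to one real number."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]                       # treat 0 * log 0 as 0
    return float(-np.sum(p * np.log(p)))

# The argument is a function (here a discrete pmf), the result a single real.
uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.97, 0.01, 0.01, 0.01]
print(entropy(uniform))   # log 4 ≈ 1.3863, the maximum over 4 outcomes
print(entropy(peaked))    # much smaller: a concentrated pmf has low entropy
```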
To find the extrema of a function $f(x)$ we use differentiation: suppose $f$ attains a minimum at $x_0$; then $f'(x_0)=0$, and for any $\epsilon$ close to $0$:

$$f(x_0)\le f(x_0+\epsilon)$$

If we define the function $\Phi(\epsilon)=f(x_0+\epsilon)$, then another way of saying that $f$ attains an extremum at $x_0$ is:

$$\Phi'(0) = \frac{d \Phi(\epsilon)}{d \epsilon} \bigg |_{\epsilon=0}=f'(x_0+0)=f'(x_0)=0$$

How, then, can we find the extremum of a higher-order function, i.e. a functional? The following comes from Wikipedia; since the author knows very little about functional analysis, it is translated here for reference :-).
Consider the functional:

$$J(y)=\int_{x_1}^{x_2}L(y(x), y'(x), x)\, dx$$

where $x_1, x_2$ are constants, $y(x)$ is twice continuously differentiable, and $L$ is twice continuously differentiable in $y$, $y'$, and $x$.
Suppose the functional $J(y)$ attains a minimum at $y=f$. Then for any function $\eta$ satisfying $\eta(x_1)=0$ and $\eta(x_2)=0$, and any sufficiently small $\epsilon$, the following inequality holds:

$$J(f) \le J(f+\epsilon \eta)$$

where $\epsilon \eta$ is called a variation of the function $f$, written $\delta f$.
Consider the function

$$\Phi(\epsilon) = J(f + \epsilon \eta)$$

As before, we have:

$$\Phi'(0)=\frac{d \Phi(\epsilon)}{d \epsilon}\bigg |_{\epsilon=0}=\frac{d J(f+\epsilon \eta)}{d\epsilon}\bigg |_{\epsilon=0}=\int _{x_1}^{x_2} \frac{d L}{d \epsilon} \bigg |_{\epsilon=0} dx = 0$$

where

$$\frac{d L}{d \epsilon} =\frac{\partial L}{\partial y}\frac{\partial y}{\partial \epsilon} + \frac{\partial L}{\partial y'}\frac{\partial y'}{\partial \epsilon}$$

Since $y = f+\epsilon \eta$ and $y' = f' + \epsilon \eta'$, we have

$$\frac{\partial y}{\partial \epsilon}=\eta, \qquad \frac{\partial y'}{\partial \epsilon}=\eta'$$

Substituting gives:

$$\frac{d L}{d \epsilon} =\frac{\partial L}{\partial y}\eta + \frac{\partial L}{\partial y'}\eta'$$

Then, by integration by parts:

$$\int _{x_1}^{x_2} \frac{d L}{d \epsilon} \bigg |_{\epsilon=0} dx =\int _{x_1}^{x_2} \bigg \{ \frac{\partial L}{\partial y}\eta + \frac{\partial L}{\partial y'}\eta'\bigg \} \bigg|_{\epsilon=0}dx \\= \int _{x_1}^{x_2} \eta\bigg \{ \frac{\partial L}{\partial f} - \frac{d}{dx} \frac{\partial L}{\partial f'}\bigg \} dx + \frac{\partial L}{\partial f'}\eta\bigg|_{x_1}^{x_2}\\= \int _{x_1}^{x_2} \eta\bigg \{ \frac{\partial L}{\partial f} - \frac{d}{dx} \frac{\partial L}{\partial f'}\bigg \} dx =0$$

Here, when $\epsilon=0$, $y \to f$ and $y' \to f'$; and since $\eta$ vanishes at $x_1$ and $x_2$, the boundary term $\frac{\partial L}{\partial f'}\eta\big|_{x_1}^{x_2}=0$. Finally, by the fundamental lemma of the calculus of variations, we obtain the Euler-Lagrange equation:

$$\frac{\partial L}{\partial f} - \frac{d}{dx} \frac{\partial L}{\partial f'}=0$$

Note that the Euler-Lagrange equation is only a necessary condition for the functional to attain an extremum, not a sufficient one. Still, it gives us a first feel for the calculus of variations, and one concrete method for finding extrema of functionals.
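A classic sanity check, sketched numerically (the grid and curves below are my own illustration): for the arc-length functional with $L=\sqrt{1+y'^2}$, the Euler-Lagrange equation reduces to $f''=0$, i.e. a straight line, and indeed $J(f)\le J(f+\epsilon \eta)$ for perturbations $\eta$ vanishing at the endpoints.

```python
import numpy as np

# Arc-length functional J(y) = ∫ sqrt(1 + y'^2) dx on [0, 1], discretized as a
# polyline; for this L the Euler-Lagrange equation gives f'' = 0, a straight line.
x = np.linspace(0.0, 1.0, 2001)

def J(y):
    dy = np.diff(y) / np.diff(x)              # slope on each small interval
    return float(np.sum(np.sqrt(1.0 + dy**2) * np.diff(x)))

f = 2.0 * x                        # candidate extremum: a straight line
eta = np.sin(np.pi * x)            # perturbation with eta(x1) = eta(x2) = 0

for eps in (0.3, 0.1, 0.01):
    assert J(f) <= J(f + eps * eta)    # J(f) ≤ J(f + ε η) for every ε tried

print(J(f))   # ≈ sqrt(5) ≈ 2.2360, the exact length of the line from (0,0) to (1,2)
```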
Mean Field Theory
Yes, variational inference really is closely tied to this theory, so it is worth a brief look now to avoid confusion later. First, the Wikipedia description:
In physics and probability theory, mean field theory (MFT also known as self-consistent field theory) studies the behavior of large and complex stochastic models by studying a simpler model. Such models consider a large number of small individual components which interact with each other. The effect of all the other individuals on any given individual is approximated by a single averaged effect, thus reducing a many-body problem to a one-body problem.
Regrettably, the author cannot offer an intuitive explanation of mean field theory and welcomes readers' insight. In short, mean field theory is a way of simplifying complex models. For example, given a probability model

$$P(x_1,x_2,x_3,...,x_n)=P(x_1)P(x_2|x_1)P(x_3|x_2,x_1)...P(x_n|x_{n-1},x_{n-2},x_{n-3},...,x_1)$$

mean field theory lets us look for another model

$$Q(x_1,x_2,x_3,...,x_n)=Q(x_1)Q(x_2)Q(x_3)...Q(x_n)$$

and make $Q$ agree with $P$ as closely as possible, so that $Q$ can serve as an approximate stand-in for $P(x_1,x_2,x_3,...,x_n)$.
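A tiny numerical illustration of the idea (the joint distribution below is made up): take a correlated joint $P(x_1,x_2)$ and approximate it by the product of its marginals. The factorized $Q$ is close to $P$ but, by construction, cannot represent the dependence exactly.

```python
import numpy as np

# A correlated joint P(x1, x2) over two binary variables.
P = np.array([[0.40, 0.10],
              [0.10, 0.40]])

# Mean-field style factorized approximation: Q(x1, x2) = Q(x1) Q(x2),
# here simply the product of the marginals of P.
q1 = P.sum(axis=1)               # marginal of x1
q2 = P.sum(axis=0)               # marginal of x2
Q = np.outer(q1, q2)

kl = float(np.sum(P * np.log(P / Q)))   # KL(P || Q): the price of ignoring correlation
print(Q)    # [[0.25 0.25]
            #  [0.25 0.25]]
print(kl)   # > 0: the factorized model cannot match the dependence exactly
```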
Variational Bayesian Inference
In Bayesian models we usually need to compute the posterior $P(Z|X)$, but in many practical models computing $P(Z|X)$ directly is intractable. Using mean field theory, we instead look for another model $Q(Z)=\prod _{i} Q(z_i)$ to approximate $P(Z|X)$; this factorization is the only assumption variational Bayesian inference makes! The problem then becomes how to find such a $Q(Z)$.
Readers familiar with information theory will know that the difference between two probability distributions can be measured with the KL divergence. So the problem becomes: find the $Q(Z)$ that minimizes
$$KL(Q||P) = \int Q(Z) \log\frac{Q(Z)}{P(Z|X)}\, dZ$$

The KL divergence is asymmetric, so why use $KL(Q||P)$ rather than $KL(P||Q)$? Consider a figure from Chapter 10 of PRML: the green contours show the distribution $P$; the red contours in panel (a) are obtained by minimizing $KL(Q||P)$, i.e. variational inference, while the red contours in panel (b) come from minimizing $KL(P||Q)$. Minimizing $KL(P||Q)$ corresponds to a different approximation framework, Expectation Propagation, which is beyond the scope of this post, so we set it aside.
Now that we have our objective, minimizing $KL(Q||P)$, the next step is to find the $Q$ that achieves the minimum. Rearranging the formula:

$$KL(Q||P) = \int Q(Z) \log\frac{Q(Z)}{P(Z|X)} dZ\\=-\int Q(Z) \log\frac{P(Z|X)}{Q(Z)} dZ\\=-\int Q(Z) \log\frac{P(Z,X)}{Q(Z)P(X)} dZ\\=\int Q(Z) [\log Q(Z)+\log P(X)] dZ-\int Q(Z) \log P(Z,X)dZ\\=\log P(X)+\int Q(Z) \log Q(Z) dZ-\int Q(Z) \log P(Z,X)dZ$$
Define

$$L(Q) =\int Q(Z) \log P(Z,X)\,dZ-\int Q(Z) \log Q(Z)\, dZ$$

Then:

$$\log P(X) = KL(Q||P) + L(Q)$$

Our goal is to minimize $KL(Q||P)$. Since the log data likelihood $\log P(X)$ does not depend on $Q$, it can be treated as a constant. So to minimize $KL(Q||P)$ we can instead maximize $L(Q)$, and the objective becomes:

$$\max \ L(Q)$$
Because $KL(Q||P) \ge 0$, we have

$$\log P(X) \ge L(Q)$$

so $L(Q)$ can be viewed as a lower bound on $\log P(X)$, usually called the ELBO (Evidence Lower Bound). In other words, we approach the log likelihood $\log P(X)$ by maximizing its lower bound.
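The decomposition $\log P(X) = KL(Q||P) + L(Q)$ can be checked numerically on a toy discrete model (the numbers below are arbitrary):

```python
import numpy as np

# Toy discrete model: Z takes 3 values; joint P(Z, X = x_obs) given as a vector.
joint = np.array([0.10, 0.25, 0.15])   # P(Z = z, X = x_obs)
evidence = joint.sum()                 # P(X = x_obs)
posterior = joint / evidence           # P(Z | X = x_obs)

Q = np.array([0.5, 0.3, 0.2])          # an arbitrary variational distribution

kl = float(np.sum(Q * np.log(Q / posterior)))                      # KL(Q || P(Z|X))
elbo = float(np.sum(Q * np.log(joint)) - np.sum(Q * np.log(Q)))    # L(Q), the ELBO

print(np.log(evidence))   # log P(X)
print(kl + elbo)          # exactly the same number: log P(X) = KL(Q||P) + L(Q)
```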
Good. The objective function is now:

$$L(Q) =\int Q(Z) \log P(Z,X)\,dZ-\int Q(Z) \log Q(Z)\, dZ$$

The mean field assumption is:

$$Q(Z) = \prod_i Q(z_i)$$

Start with the second term on the right-hand side:

$$\int Q(Z) \log Q(Z) dZ = \int \prod_i Q(z_i)\log\prod_j Q(z_j) dZ \\= \int \prod_i Q(z_i)\sum_j \log Q(z_j) dZ \\= \sum_j \int \prod_i Q(z_i)\log Q(z_j) dZ \\= \sum_j \int Q(z_j)\log Q(z_j)dz_j \int \prod_{i: i\neq j} Q(z_i)\, dz_i \\= \sum_j \int Q(z_j)\log Q(z_j)dz_j$$

Remarkable: with nothing but the mean field assumption, the originally high-dimensional distribution $Q(Z)$ decomposes into single-variable terms. Note the notation $dZ = \prod_i dz_i$ here (it does not denote differentiation of a vector), and since $\int Q(z_i) dz_i = 1$, we have $\int \prod_{i: i\neq j} Q(z_i)\, dz_i=\prod_{i: i\neq j} \int Q(z_i) dz_i=1$.
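This factorization of the entropy-like term can be verified on a small discrete example (sums standing in for the integrals; the numbers are arbitrary):

```python
import numpy as np

# Discrete check of the decomposition: if Q(z1, z2) = Q(z1) Q(z2), then
# E_Q[log Q(Z)] = sum over j of E[log Q(z_j)] — the joint term splits per factor.
q1 = np.array([0.2, 0.8])
q2 = np.array([0.5, 0.3, 0.2])
Q = np.outer(q1, q2)                   # Q(z1, z2) = Q(z1) Q(z2)

lhs = float(np.sum(Q * np.log(Q)))                                    # ∫ Q(Z) log Q(Z) dZ
rhs = float(np.sum(q1 * np.log(q1)) + np.sum(q2 * np.log(q2)))        # per-factor sums
print(lhs, rhs)   # identical values
```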
Now the other part:

$$\int Q(Z) \log P(Z,X)dZ = \int \prod_i Q(z_i) \log P(Z,X)dZ\\=\int Q(z_j) \bigg (\int \prod_{i:i\neq j} Q(z_i) \log P(Z,X)\,dz_i \bigg )dz_j\\=\int Q(z_j)\, E_{i\neq j}[\log P(Z,X)]\, dz_j \\= \int Q(z_j)\log \big \{\exp( E_{i\neq j}[\log P(Z,X)])\big \} dz_j \\=\int Q(z_j)\log\frac{\exp(E_{i\neq j}[\log P(Z,X)])}{\int \exp(E_{i\neq j}[\log P(Z,X)])\,dz_j}\, dz_j - C\\= \int Q(z_j)\log Q^*(z_j)\, dz_j - C$$

where $E_{i\neq j}[\cdot]$ denotes the expectation with respect to all factors $Q(z_i)$ with $i\neq j$, and $C$ is a constant that does not depend on $Q(z_j)$ (it absorbs the log of the normalizing integral).
Combining the two parts gives:

$$L(Q) = \int Q(z_j)\log Q^*(z_j)\, dz_j - \sum_j\int Q(z_j)\log Q(z_j)\,dz_j - C \\=\int Q(z_j)\log\frac{Q^*(z_j)}{Q(z_j)}\, dz_j -\sum_{i:i \neq j} \int Q(z_i)\log Q(z_i)\,dz_i - C \\= -KL(Q(z_j)||Q^*(z_j)) + \sum_{i:i\neq j}H(Q(z_i)) - C$$

where $H(Q(z_i)) = -\int Q(z_i) \log Q(z_i)\,dz_i$ is the entropy. When we optimize a single factor $Q(z_j)$ with all other factors held fixed, every term except $-KL(Q(z_j)||Q^*(z_j))$ is constant with respect to $Q(z_j)$, and since

$$KL(Q(z_j)||Q^*(z_j)) \ge 0$$

maximizing $L(Q)$ over $Q(z_j)$ amounts to making

$$-KL(Q(z_j)||Q^*(z_j))=0$$

that is, setting

$$Q(z_j) = Q^*(z_j) = \frac{\exp(E_{i\neq j}[\log P(Z,X)])}{\text{normalizing constant}}$$

If you prefer, the same optimum can also be derived directly with the calculus of variations, combined with Lagrange multipliers:
$$\frac{\delta}{\delta Q(z_j)} \bigg \{ \int Q(z_j)\log Q^*(z_j)\, dz_j - \int Q(z_j)\log Q(z_j)\,dz_j + \lambda_j\Big( \int Q(z_j)\,dz_j -1\Big)\bigg \}$$

Setting this variation to zero and working through the algebra yields the same result as above.
At this point we have derived, in full generality, how variational Bayesian inference computes its updates. The algorithm is:
- Loop until convergence:
  - For each factor $Q(z_j)$:
    - Set $Q(z_j) = Q^*(z_j)$
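The loop above can be sketched on a concrete model (I use the 2-D Gaussian target from PRML §10.1.2; the specific mean and covariance below are made up). For $P(Z)=\mathcal{N}(m,\Sigma)$ with mean-field $Q(z_1)Q(z_2)$, each optimal factor $Q^*(z_j)$ is itself Gaussian with a closed-form mean update, so "set $Q(z_j)=Q^*(z_j)$" becomes a single line:

```python
import numpy as np

# Mean-field coordinate updates for a 2-D Gaussian target P(z) = N(m, Sigma).
# Each Q*(z_j) ∝ exp(E_{i≠j}[log P(z)]) is Gaussian with mean
#   m_j - Λ_jj^{-1} Λ_jk (E[z_k] - m_k)   and variance 1/Λ_jj.
m = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)               # precision matrix Λ

mu = np.zeros(2)                         # current means of Q(z1), Q(z2)
for _ in range(100):                     # "loop until convergence"
    for j in (0, 1):                     # "for each factor Q(z_j): set it to Q*(z_j)"
        k = 1 - j
        mu[j] = m[j] - Lam[j, k] / Lam[j, j] * (mu[k] - m[k])

print(mu)                  # converges to the true mean [1, -1]
print(1.0 / np.diag(Lam))  # factor variances 1/Λ_jj, smaller than Sigma's diagonal
```

Note that the converged variances understate the marginal variances of $P$: this is exactly the compact, mode-hugging behavior of minimizing $KL(Q||P)$ discussed earlier.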
Although we have derived the variational inference framework in theory, for each concrete model we must still derive $Q^*(z_j)$ by hand. Briefly, deriving a variational Bayesian model takes four steps:
- Choose conjugate prior distributions for the model parameters (if you want a full Bayesian model)
- Write down the model's joint distribution $P(Z,X)$
- Determine the form of the variational distribution $Q(Z)$ from the joint distribution
- For each variational factor $Q(z_j)$, take the expectation of $\log P(Z,X)$ with respect to all variables other than $z_j$, then normalize the result into a probability distribution
This process is not simple: for real models the derivation tends to be long, tedious, and error-prone; readers who want to see it in full can follow the links at the end. This is why researchers later developed a more general, more automated framework based on probabilistic graphical models, VMP (Variational Message Passing). For models in the exponential family, VMP can derive the algorithm automatically :-). Interested readers can consult "Variational Message Passing".
References
"A Tutorial on Variational Bayesian Inference"
"Pattern Recognition and Machine Learning", Chapter 10
Readers interested in seeing variational inference at work on LDA can consult :-)
"Latent Dirichlet Allocation"