
Variational Inference

Hady W. Lauw

IS712 Machine Learning


Cross Entropy
• Cost of modeling some distribution $p(x)$ with a different distribution $q(x)$:
$$H(p, q) = \int -p(x) \log q(x) \, dx$$

• Kullback-Leibler divergence or KL-divergence is the additional cost as compared to using the original distribution $p(x)$:
$$D_{KL}(p||q) = \int -p(x) \log q(x) \, dx - \int -p(x) \log p(x) \, dx$$
$$D_{KL}(p||q) = -\int p(x) \log \frac{q(x)}{p(x)} \, dx$$
– Always non-negative, i.e., $D_{KL}(p||q) \geq 0$
– Minimized when $p(x)$ and $q(x)$ are identical distributions
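As a quick illustration (not from the slides), the sketch below computes cross-entropy and KL-divergence for two small discrete distributions with NumPy; the distributions themselves are hypothetical.

```python
import numpy as np

# Two discrete distributions over the same support (hypothetical values).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])

# Cross-entropy H(p, q) = -sum_x p(x) log q(x)
cross_entropy = -np.sum(p * np.log(q))

# Entropy H(p) = -sum_x p(x) log p(x)
entropy_p = -np.sum(p * np.log(p))

# KL-divergence is the additional cost of modeling p with q: D_KL(p||q) = H(p, q) - H(p)
kl_pq = cross_entropy - entropy_p

print(cross_entropy, entropy_p, kl_pq)                # kl_pq >= 0
print(np.isclose(kl_pq, -np.sum(p * np.log(q / p))))  # same as -sum_x p(x) log q(x)/p(x)
```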
Kullback-Leibler Divergence (KL-Divergence)

Figure: https://en.wikipedia.org/wiki/File:KL-Gauss-Example.png
Latent Variable Model

• Suppose that we have a probabilistic model with observed variables $\boldsymbol{x}$ and hidden variables $\boldsymbol{z}$, with its joint distribution parameterized by $\theta$
– Our goal is to maximize the likelihood function given by:
$$p(\boldsymbol{x}|\theta) = \sum_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}|\theta)$$
– Suppose direct optimization of $p(\boldsymbol{x}|\theta)$ is difficult, but that of $p(\boldsymbol{x}, \boldsymbol{z}|\theta)$ is easier

• Let $q(\boldsymbol{z})$ be a distribution over the latent variables. For any choice of distribution $q(\boldsymbol{z})$, the following decomposition holds:
$$\ln p(\boldsymbol{x}|\theta) = \mathcal{L}(q, \theta) + D_{KL}(q||p)$$
$$\mathcal{L}(q, \theta) = \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z}|\theta)}{q(\boldsymbol{z})}$$
$$D_{KL}(q||p) = -\sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{z}|\boldsymbol{x}, \theta)}{q(\boldsymbol{z})}$$
Decomposition

• We observe that:
$$p(\boldsymbol{x}, \boldsymbol{z}|\theta) = p(\boldsymbol{z}|\boldsymbol{x}, \theta) \, p(\boldsymbol{x}|\theta)$$

• Substitute this into $\mathcal{L}(q, \theta)$, and we get:
$$\begin{aligned}
\mathcal{L}(q, \theta) &= \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z}|\theta)}{q(\boldsymbol{z})} \\
&= \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{z}|\boldsymbol{x}, \theta) \, p(\boldsymbol{x}|\theta)}{q(\boldsymbol{z})} \\
&= \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{z}|\boldsymbol{x}, \theta)}{q(\boldsymbol{z})} + \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln p(\boldsymbol{x}|\theta) \\
&= -D_{KL}(q||p) + \ln p(\boldsymbol{x}|\theta)
\end{aligned}$$

• We get back the original decomposition:
$$\ln p(\boldsymbol{x}|\theta) = \mathcal{L}(q, \theta) + D_{KL}(q||p)$$
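A minimal numerical check of this decomposition, assuming a toy model with one observation and a discrete latent variable; the probabilities and the choice of $q(\boldsymbol{z})$ are made up for illustration.

```python
import numpy as np

# Toy latent-variable model for a single observed x and a discrete z (hypothetical numbers).
prior = np.array([0.6, 0.3, 0.1])        # p(z)
lik   = np.array([0.2, 0.5, 0.9])        # p(x | z) evaluated at the observed x
joint = prior * lik                      # p(x, z) = p(z) p(x | z)
evidence = joint.sum()                   # p(x) = sum_z p(x, z)
posterior = joint / evidence             # p(z | x)

q = np.array([0.3, 0.4, 0.3])            # any distribution over z

elbo = np.sum(q * np.log(joint / q))     # L(q) = sum_z q(z) ln p(x, z)/q(z)
kl   = np.sum(q * np.log(q / posterior)) # D_KL(q || p(z|x))

print(np.isclose(np.log(evidence), elbo + kl))  # ln p(x) = L(q) + D_KL(q||p)
print(elbo <= np.log(evidence))                 # L(q) is a lower bound on the log-evidence
```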
Illustration of the Decomposition

• Because the KL-divergence is non-negative, $\mathcal{L}(q, \theta)$ is effectively a lower bound on the log-likelihood $\ln p(\boldsymbol{x}|\theta)$
Revisiting EM: E-Step

• When $q = p$, we get back EM
• In the E-step, holding the old parameters $\theta^{old}$ fixed, we maximize $\mathcal{L}(q, \theta^{old})$ w.r.t. $q(\boldsymbol{z})$
– This happens when the KL-divergence is 0, and thus $q(\boldsymbol{z}) = p(\boldsymbol{z}|\boldsymbol{x}, \theta^{old})$
Revisiting EM: M-Step

• In the M-step, holding $q(\boldsymbol{z})$ fixed, we maximize $\mathcal{L}(q, \theta)$ w.r.t. the parameters $\theta$
– This increases $\mathcal{L}(q, \theta)$, which is a lower bound on the log-likelihood
– The new posterior $p(\boldsymbol{z}|\boldsymbol{x}, \theta^{new}) \neq p(\boldsymbol{z}|\boldsymbol{x}, \theta^{old})$, so it creates a non-zero KL-divergence
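The slides describe EM abstractly; as an illustrative sketch (not the lecture's code), the NumPy loop below runs EM for a two-component 1-D Gaussian mixture. The E-step sets $q(\boldsymbol{z})$ to the posterior responsibilities under $\theta^{old}$, and the M-step re-estimates $\theta$; the data and initialization are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic 1-D data from two Gaussians (hypothetical parameters).
x = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(3.0, 1.0, 300)])

# Initial parameters theta = (mixing weights pi, means mu, variances var)
pi, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])

for _ in range(50):
    # E-step: q(z) = p(z | x, theta_old), i.e. the per-point responsibilities
    dens = np.exp(-0.5 * (x[:, None] - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)
    resp = pi * dens
    resp /= resp.sum(axis=1, keepdims=True)

    # M-step: maximize L(q, theta) w.r.t. theta while holding q fixed
    Nk = resp.sum(axis=0)
    pi = Nk / len(x)
    mu = (resp * x[:, None]).sum(axis=0) / Nk
    var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / Nk

print(pi, mu, var)  # should approach the true mixing weights and component parameters
```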
Variational Inference

• For EM, we assume that the evaluation of the posterior $p(\boldsymbol{z}|\boldsymbol{x}, \theta)$ is tractable
• For many models, this evaluation may not be tractable
– E.g., when we introduce a prior that requires complex normalization, or when the dimensionality is too high
• Approximate inference by using a different distribution $q \neq p$ that is simpler and more tractable
– The previous parameters $\theta$ are now absorbed into $\boldsymbol{z}$
$$\ln p(\boldsymbol{x}) = \mathcal{L}(q) + D_{KL}(q||p)$$
$$\mathcal{L}(q) = \int q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})} \, d\boldsymbol{z}$$
$$D_{KL}(q||p) = -\int q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{z}|\boldsymbol{x})}{q(\boldsymbol{z})} \, d\boldsymbol{z}$$
Jensen's Inequality
• For a convex function (any minimum is a global minimum): $f(\mathrm{E}[x]) \leq \mathrm{E}[f(x)]$
• For a concave function: $\mathrm{E}[f(x)] \leq f(\mathrm{E}[x])$
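A quick numerical illustration of both directions of the inequality, using a hypothetical positive random sample.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.gamma(shape=2.0, scale=1.0, size=100_000)  # any positive random variable

# log is concave, so E[log x] <= log E[x]
print(np.mean(np.log(x)), np.log(np.mean(x)))

# x**2 is convex, so (E[x])**2 <= E[x**2]
print(np.mean(x) ** 2, np.mean(x ** 2))
```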
Evidence Lower Bound (ELBO)
• Another view of $\mathcal{L}(q)$
• Evidence or likelihood:
$$\ln p(\boldsymbol{x}) = \ln \int p(\boldsymbol{x}, \boldsymbol{z}) \, d\boldsymbol{z}
= \ln \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})} \, d\boldsymbol{z}
= \ln \mathrm{E}_q\!\left[\frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right]$$

• Based on Jensen's inequality, the evidence lower bound or ELBO:
$$\mathcal{L}(q) = \int q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})} \, d\boldsymbol{z}
= \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right]
\leq \ln \mathrm{E}_q\!\left[\frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right]$$
Evidence Lower Bound (ELBO)
• To maximize the likelihood, we maximize the ELBO
• Interpretation of $\mathcal{L}(q)$:
$$\mathcal{L}(q) = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right] = \mathrm{E}_q[\ln p(\boldsymbol{x}, \boldsymbol{z})] - \mathrm{E}_q[\ln q(\boldsymbol{z})]$$
– $\mathrm{E}_q[\ln p(\boldsymbol{x}, \boldsymbol{z})]$ is the expectation (under $q$) of the log of the joint probability
– $H(q) = -\mathrm{E}_q[\ln q(\boldsymbol{z})]$ is the entropy of the variational distribution

• As previously shown, maximizing the ELBO is equivalent to minimizing $D_{KL}(q||p)$
Mean Field Variational Inference

• The variational distribution $q$ should be simple and tractable
• A frequent assumption is that the latent variables are independent
• The variational distribution factorizes:
$$q(\boldsymbol{z}) = q(z_1, z_2, \ldots, z_M) = \prod_{i=1}^{M} q(z_i)$$
• Also possible to group some (dependent) variables together to form partitions
Mean Field Variational Inference
• With the factorized variational distribution:
$$\mathcal{L}(q) = \int q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})} \, d\boldsymbol{z}
= \int \prod_i q(z_i) \left[ \ln p(\boldsymbol{x}, \boldsymbol{z}) - \sum_i \ln q(z_i) \right] d\boldsymbol{z}$$

• Dissect the dependence on one factor:
$$\begin{aligned}
\mathcal{L}(q) &= \int q(z_j) \left[ \int \ln p(\boldsymbol{x}, \boldsymbol{z}) \prod_{i \neq j} q(z_i) \, d\boldsymbol{z}_{-j} \right] dz_j - \int q(z_j) \ln q(z_j) \, dz_j + \text{const} \\
&= \int q(z_j) \, \mathrm{E}_{i \neq j}[\ln p(\boldsymbol{x}, \boldsymbol{z})] \, dz_j - \int q(z_j) \ln q(z_j) \, dz_j + \text{const}
\end{aligned}$$

• Holding $\{q(z_i)\}_{i \neq j}$ fixed, maximize $\mathcal{L}(q)$ w.r.t. each $q(z_j)$ in turn
General Steps

• Identify what distribution $q$ should be, e.g., Gaussian, Dirichlet
• Derive the ELBO
• Optimize the ELBO via gradient ascent for each $q(z_j)$ in turn
• Repeat till convergence (see the sketch below)
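A minimal sketch of these steps for a toy model with two discrete latent variables, using the standard closed-form mean-field coordinate update $q(z_j) \propto \exp(\mathrm{E}_{i \neq j}[\ln p(\boldsymbol{x}, \boldsymbol{z})])$ in place of generic gradient ascent; the joint table is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
K = 4
# Unnormalized joint p(x, z1, z2) for the observed x, as a K x K table (hypothetical values).
joint = rng.random((K, K)) + 0.1
log_joint = np.log(joint)

q1 = np.full(K, 1.0 / K)   # q(z1), initialized uniform
q2 = np.full(K, 1.0 / K)   # q(z2), initialized uniform

def elbo(q1, q2):
    q = np.outer(q1, q2)   # factorized q(z1, z2) = q(z1) q(z2)
    return np.sum(q * (log_joint - np.log(q)))

for _ in range(100):
    # Maximizing L(q) w.r.t. q(z1) with q(z2) fixed gives q(z1) ∝ exp(E_{q(z2)}[ln p(x, z1, z2)])
    q1 = np.exp(log_joint @ q2); q1 /= q1.sum()
    # ... and symmetrically for q(z2)
    q2 = np.exp(q1 @ log_joint); q2 /= q2.sum()

# The ELBO never exceeds the log-evidence ln p(x) = ln sum_{z1, z2} p(x, z1, z2)
print(elbo(q1, q2), np.log(joint.sum()))
```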
VARIATIONAL INFERENCE ON LDA

Latent Dirichlet Allocation
• LDA's generative process for each document $d_i$:
– pick a topic distribution $\theta_i$ from a Dirichlet prior: $\theta_i \sim Dir(\alpha)$
– for each of the $N_i$ words in $d_i$:
• pick a latent class $z_n$ with probability $p(z_n|\theta_i)$
• pick a word $w_n$ with probability $p(w_n|z_n, \beta)$

• Document probability:
$$p(\boldsymbol{w}_i|\alpha, \beta) = \int p(\theta_i|\alpha) \prod_{n=1}^{N_i} \sum_{z_n} p(w_n|z_n, \beta) \, p(z_n|\theta_i) \, d\theta_i$$
– $\boldsymbol{w}_i$ are the observed words in document $d_i$
– $N_i$ is the number of words in $d_i$
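A small sketch of this generative process in NumPy; the number of topics, vocabulary size, document length, and hyperparameter values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)
K, V, N = 3, 10, 20                       # topics, vocabulary size, words per document (hypothetical)
alpha = np.full(K, 0.5)                   # Dirichlet hyperparameter over topics
beta = rng.dirichlet(np.full(V, 0.1), K)  # beta[k, v] = p(word v | topic k)

def generate_document():
    theta = rng.dirichlet(alpha)          # topic distribution theta_i ~ Dir(alpha)
    words = []
    for _ in range(N):
        z = rng.choice(K, p=theta)        # latent class z_n ~ p(z | theta_i)
        w = rng.choice(V, p=beta[z])      # word w_n ~ p(w | z_n, beta)
        words.append(w)
    return theta, words

theta, words = generate_document()
print(theta, words)
```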
Variational Inference
• Posterior distribution of the hidden variables:
$$p(\theta, \boldsymbol{z}|\boldsymbol{w}, \alpha, \beta) = \frac{p(\theta, \boldsymbol{z}, \boldsymbol{w}|\alpha, \beta)}{p(\boldsymbol{w}|\alpha, \beta)}$$
– drop the index $i$ for a specific document
– intractable to compute because of the coupling of $\theta$ and $\beta$

• Variational inference
– Simplified graphical model with fewer dependencies:
$$q(\theta, \boldsymbol{z}|\gamma, \boldsymbol{\phi}) = q(\theta|\gamma) \prod_{n=1}^{N} q(z_n|\phi_n)$$
– $q(\theta|\gamma)$ for every document is Dirichlet over $K$ topics
– $q(z_n|\phi_n)$ for every token is Multinomial over $K$ topics

• Optimization objective:
$$(\gamma^*, \boldsymbol{\phi}^*) = \operatorname{argmin}_{(\gamma, \boldsymbol{\phi})} D_{KL}\big(q(\theta, \boldsymbol{z}|\gamma, \boldsymbol{\phi}) \,||\, p(\theta, \boldsymbol{z}|\boldsymbol{w}, \alpha, \beta)\big)$$
Optimization
• KL minimization is equivalent to ELBO maximization:
$$\mathrm{ELBO} = \mathrm{E}_q[\ln p(\theta, \boldsymbol{z}, \boldsymbol{w}|\alpha, \beta)] + H(q)$$

• The expectation of the log of the joint probability:
$$\begin{aligned}
\mathrm{E}_q[\ln p(\theta, \boldsymbol{z}, \boldsymbol{w}|\alpha, \beta)] &= \mathrm{E}_q[\ln p(\theta|\alpha)] + \mathrm{E}_q[\ln p(\boldsymbol{z}|\theta)] + \mathrm{E}_q[\ln p(\boldsymbol{w}|\boldsymbol{z}, \beta)] \\
&= \log \Gamma\Big(\sum_{k=1}^{K} \alpha_k\Big) - \sum_{k=1}^{K} \log \Gamma(\alpha_k) + \sum_{k=1}^{K} (\alpha_k - 1) \Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K} \gamma_{k'}\Big)\Big) \\
&\quad + \sum_{n=1}^{N} \sum_{k=1}^{K} \phi_{nk} \Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K} \gamma_{k'}\Big)\Big) + \sum_{n=1}^{N} \sum_{k=1}^{K} \sum_{v=1}^{V} \phi_{nk} \, w_{nv} \log \beta_{kv}
\end{aligned}$$

• The entropy of the variational distribution:
$$\begin{aligned}
H(q) &= H(\gamma) + \sum_{n=1}^{N} \sum_{k=1}^{K} H(\phi_{nk}) \\
&= -\log \Gamma\Big(\sum_{k=1}^{K} \gamma_k\Big) + \sum_{k=1}^{K} \log \Gamma(\gamma_k) - \sum_{k=1}^{K} (\gamma_k - 1) \Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K} \gamma_{k'}\Big)\Big) - \sum_{n=1}^{N} \sum_{k=1}^{K} \phi_{nk} \log \phi_{nk}
\end{aligned}$$

Legend:
• $K$: no. of topics
• $N$: no. of tokens in the document
• $V$: no. of unique words in the vocabulary
• $\alpha_k$: Dirichlet hyperparameter for topic $k$
• $\gamma_k$: variational parameter for topic $k$ in this document
• $\phi_{nk}$: variational parameter for the assignment of token $n$ to topic $k$
• $w_{nv}$: token $n$ having word form $v$
• $\beta_{kv}$: probability of topic $k$ generating word form $v$
• $\Gamma$: Gamma function
• $\Psi$: Digamma function

For the full derivation refer to https://youtu.be/2pEkWk-LHmU
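A minimal sketch of evaluating these ELBO terms for one document with scipy.special; the function name `lda_elbo` and its argument layout are assumptions for illustration.

```python
import numpy as np
from scipy.special import digamma, gammaln

def lda_elbo(words, alpha, beta, gamma, phi):
    """Per-document ELBO for LDA, following the terms on this slide.
    words: token word ids (length N); alpha: length-K prior; beta: K x V topic-word matrix;
    gamma: length-K Dirichlet parameters; phi: N x K token-topic weights (strictly positive)."""
    dig = digamma(gamma) - digamma(gamma.sum())  # E_q[ln theta_k] = Psi(gamma_k) - Psi(sum_k gamma_k)

    e_log_p_theta = gammaln(alpha.sum()) - gammaln(alpha).sum() + ((alpha - 1) * dig).sum()
    e_log_p_z     = (phi * dig).sum()
    e_log_p_w     = (phi * np.log(beta[:, words].T)).sum()

    h_q_theta = -gammaln(gamma.sum()) + gammaln(gamma).sum() - ((gamma - 1) * dig).sum()
    h_q_z     = -(phi * np.log(phi)).sum()

    return e_log_p_theta + e_log_p_z + e_log_p_w + h_q_theta + h_q_z
```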
Variational Inference Algorithm for LDA

• Randomly initialize the variational parameters
• For each iteration:
– For each document $i$ (index dropped), update $\gamma$. For each token in the document, update $\phi$:
$$\gamma_k = \alpha_k + \sum_{n=1}^{N} \phi_{nk}$$
$$\phi_{nk} \propto \beta_{kv} \exp\Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K} \gamma_{k'}\Big)\Big) \quad \text{normalize so that } \textstyle\sum_{k=1}^{K} \phi_{nk} = 1$$
– For the corpus, update $\beta$:
$$\beta_{kv} \propto \sum_{i} \sum_{n=1}^{N_i} \phi_{ink} \cdot w_{inv} \quad \text{normalize so that } \textstyle\sum_{v=1}^{V} \beta_{kv} = 1$$
– Compute $\mathcal{L}$ to assess convergence
• Return the expectation of the variational parameters for the solution to the latent variables
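A NumPy sketch of these updates; the helper names `update_document` and `update_beta`, the initialization, and the inner iteration count are assumptions, not the lecture's code.

```python
import numpy as np
from scipy.special import digamma

def update_document(words, alpha, beta, n_iters=50):
    """Per-document variational updates for LDA.
    words: token word ids; alpha: length-K Dirichlet prior; beta: K x V topic-word matrix."""
    N, K = len(words), len(alpha)
    phi = np.full((N, K), 1.0 / K)      # q(z_n | phi_n), uniform initialization
    gamma = alpha + N / K               # q(theta | gamma) initialization
    for _ in range(n_iters):
        # phi_{nk} ∝ beta_{k, w_n} * exp(Psi(gamma_k) - Psi(sum_k' gamma_k'))
        phi = beta[:, words].T * np.exp(digamma(gamma) - digamma(gamma.sum()))
        phi /= phi.sum(axis=1, keepdims=True)   # normalize over topics
        # gamma_k = alpha_k + sum_n phi_{nk}
        gamma = alpha + phi.sum(axis=0)
    return gamma, phi

def update_beta(docs, phis, K, V):
    """Corpus-level update: beta_{kv} ∝ sum_i sum_n phi_{ink} * [w_{in} = v]."""
    beta = np.zeros((K, V))
    for words, phi in zip(docs, phis):
        # accumulate phi_{nk} into the rows for the observed words:
        # equivalent to: for n, w in enumerate(words): beta[:, w] += phi[n]
        np.add.at(beta.T, words, phi)
    beta /= beta.sum(axis=1, keepdims=True)     # normalize over the vocabulary
    return beta
```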
VARIATIONAL AUTOENCODER (VAE)

Variational Autoencoder (VAE)

• Designed not only for content encoding, but also for content generation

• Key idea:
– regularize the latent space of the encoding so that nearby points in the latent space will
generate similar decodings

https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
Generative Model

• Let us consider a dataset $\{\boldsymbol{x}_i\}_{i=1}^{N}$
– Data is generated by a random process involving latent variables $\boldsymbol{z}$

• Generative process:
– For each data point $i = 1, \ldots, N$:
• Sample latent variable $\boldsymbol{z}_i$ from a parameterized distribution $p(\boldsymbol{z}|\theta)$
• Sample observation $\boldsymbol{x}_i$ from a parameterized distribution $p(\boldsymbol{x}|\boldsymbol{z}, \theta)$

• Suppose the integral of the marginal likelihood $p(\boldsymbol{x}|\theta) = \int p(\boldsymbol{z}|\theta) \, p(\boldsymbol{x}|\boldsymbol{z}, \theta) \, d\boldsymbol{z}$ is intractable; correspondingly, so is the true posterior
$$p(\boldsymbol{z}|\boldsymbol{x}, \theta) = \frac{p(\boldsymbol{x}|\boldsymbol{z}, \theta) \, p(\boldsymbol{z}|\theta)}{p(\boldsymbol{x}|\theta)}$$
Variational Inference

• Introduce $q(\boldsymbol{z}|\boldsymbol{x}, \phi)$ as an approximation to $p(\boldsymbol{z}|\boldsymbol{x}, \theta)$

• Interpretation:
– $\boldsymbol{z}$ is the encoding of input $\boldsymbol{x}$
– $q(\boldsymbol{z}|\boldsymbol{x}, \phi)$ is the encoder
– $p(\boldsymbol{x}|\boldsymbol{z}, \theta)$ is the decoder

• Objective is to maximize the ELBO:
$$\mathcal{L}(q) = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z}|\theta)}{q(\boldsymbol{z}|\boldsymbol{x}, \phi)}\right]$$

• Equivalently, to minimize the following KL-divergence:
$$D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,||\, p(\boldsymbol{z}|\boldsymbol{x}, \theta)\big)$$
Decoder and Encoder

Decoder
• Let $p(\boldsymbol{z}|\theta)$ be a factorized multivariate Gaussian:
$$\ln p(\boldsymbol{z}|\theta) = \ln \mathcal{N}(\boldsymbol{z}; \boldsymbol{0}, \mathbf{I})$$
• Let $p(\boldsymbol{x}|\boldsymbol{z}, \theta)$ be a factorized multivariate Gaussian:
$$\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta) = \ln \mathcal{N}(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2 \mathbf{I})$$
– $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$ are outputs of multi-layer perceptrons parameterized by $\theta$ (network weights):
$$\boldsymbol{h} = \tanh(\mathbf{W}_1 \boldsymbol{z} + \boldsymbol{b}_1), \quad \boldsymbol{\mu} = \mathbf{W}_2 \boldsymbol{h} + \boldsymbol{b}_2, \quad \boldsymbol{\sigma}^2 = \mathbf{W}_3 \boldsymbol{h} + \boldsymbol{b}_3$$
– the latent representation $\boldsymbol{h}$ is shared by $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$

Encoder
• Let $q(\boldsymbol{z}|\boldsymbol{x}, \phi)$ be a factorized multivariate Gaussian:
$$\ln q(\boldsymbol{z}|\boldsymbol{x}, \phi) = \ln \mathcal{N}(\boldsymbol{z}; \boldsymbol{\rho}, \boldsymbol{\omega}^2 \mathbf{I})$$
– $\boldsymbol{\rho}$ and $\boldsymbol{\omega}^2$ are outputs of multi-layer perceptrons parameterized by $\phi$ (network weights) on input $\boldsymbol{x}$
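A minimal PyTorch sketch of such an encoder and decoder. Following common practice (and slightly deviating from the slide's $\boldsymbol{\sigma}^2$ and $\boldsymbol{\omega}^2$ notation), the networks below output log-variances for numerical stability; the class names and layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q(z | x, phi): outputs mean rho and log-variance log(omega^2) of a diagonal Gaussian."""
    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(x_dim, h_dim), nn.Tanh())
        self.rho = nn.Linear(h_dim, z_dim)
        self.log_omega2 = nn.Linear(h_dim, z_dim)  # log-variance instead of variance

    def forward(self, x):
        h = self.hidden(x)                         # shared hidden representation
        return self.rho(h), self.log_omega2(h)

class Decoder(nn.Module):
    """p(x | z, theta): outputs mean mu and log-variance log(sigma^2) of a diagonal Gaussian."""
    def __init__(self, z_dim, h_dim, x_dim):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(z_dim, h_dim), nn.Tanh())  # h = tanh(W1 z + b1)
        self.mu = nn.Linear(h_dim, x_dim)          # mu = W2 h + b2
        self.log_sigma2 = nn.Linear(h_dim, x_dim)  # shares h with mu

    def forward(self, z):
        h = self.hidden(z)
        return self.mu(h), self.log_sigma2(h)
```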
Illustration
• Encoded distributions are Normal distributions (parameterized by mean and variance)

https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf
https://www.cl.cam.ac.uk/~pv273/slides/UCLSlides.pdf
Optimization (for Gaussian)
• Objective is to maximize the ELBO:
$$\mathcal{L}(q) = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z}|\theta)}{q(\boldsymbol{z}|\boldsymbol{x}, \phi)}\right]
= \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}|\boldsymbol{z}, \theta) \, p(\boldsymbol{z}|\theta)}{q(\boldsymbol{z}|\boldsymbol{x}, \phi)}\right]
= \mathrm{E}_q[\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta)] - D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,||\, p(\boldsymbol{z}|\theta)\big)$$

• The first term $\mathrm{E}_q[\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta)]$ is the expected (negative) reconstruction error
– for each point, estimated by taking $L$ samples from $q(\boldsymbol{z}|\boldsymbol{x}, \phi)$ via Monte Carlo estimation:
$$\mathrm{E}_q[\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta)] \approx \frac{1}{L} \sum_{l=1}^{L} \ln p(\boldsymbol{x}|\boldsymbol{z}^{(l)})$$
– requires the "reparameterization trick" to make backpropagation and gradient descent possible

• The second term $-D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,||\, p(\boldsymbol{z}|\theta)\big)$ regularizes $q(\boldsymbol{z}|\boldsymbol{x}, \phi)$ towards the prior $\mathcal{N}(\boldsymbol{0}, \mathbf{I})$
– Can be derived analytically (no sampling needed):
$$-D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,||\, p(\boldsymbol{z}|\theta)\big) = \frac{1}{2} \sum_{m=1}^{M} \big(1 + \ln(\omega_m^2) - \rho_m^2 - \omega_m^2\big)$$
– Here $M$ is the dimensionality of $\boldsymbol{z}$ (the size of the latent variables)
– $\omega_m$ and $\rho_m$ are neural network functions of $\boldsymbol{x}$ and $\phi$
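A PyTorch sketch of this objective, assuming the hypothetical `Encoder`/`Decoder` modules above: the KL term uses the analytic expression, and the reconstruction term uses $L$ reparameterized samples.

```python
import math
import torch

def vae_loss(x, encoder, decoder, L=1):
    """Negative ELBO for a Gaussian VAE: the reconstruction term is a Monte Carlo estimate
    using the reparameterization trick; the KL term is computed analytically against N(0, I)."""
    rho, log_omega2 = encoder(x)                      # parameters of q(z | x, phi)

    # -D_KL(q || N(0, I)) = 0.5 * sum_m (1 + ln(omega_m^2) - rho_m^2 - omega_m^2)
    neg_kl = 0.5 * torch.sum(1 + log_omega2 - rho.pow(2) - log_omega2.exp(), dim=-1)

    # E_q[ln p(x | z, theta)] ≈ (1/L) sum_l ln p(x | z^(l)), with z = rho + omega * eps
    recon = 0.0
    for _ in range(L):
        eps = torch.randn_like(rho)                   # eps ~ N(0, I)
        z = rho + torch.exp(0.5 * log_omega2) * eps   # reparameterization trick
        mu, log_sigma2 = decoder(z)                   # parameters of p(x | z, theta)
        log_px = -0.5 * torch.sum(
            math.log(2 * math.pi) + log_sigma2 + (x - mu).pow(2) / log_sigma2.exp(), dim=-1)
        recon = recon + log_px / L

    return -(recon + neg_kl).mean()                   # minimize the negative ELBO
```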
Visualization of Learned Data Manifold

Different Dimensionalities of Latent Space

Conclusion
• Variational Inference:
– Approximates inference in a latent variable model with simpler, more tractable inference
– Maximizes the evidence lower bound (ELBO), or equivalently minimizes the KL-divergence between the variational distribution and the posterior distribution
– Mean field inference is one simplification, where the variational distribution factorizes
• Latent Dirichlet Allocation:
– Variational inference is an alternative to Gibbs sampling
– Uses a variational distribution that decouples some dependent variables
– The variational distribution factorizes, thus the learning algorithm is parallelizable
• Variational Auto-Encoder:
– An auto-encoder that is also a generative model
– Encoding is not deterministic, but instead is a distribution
– Lends itself naturally to variational inference
– Relies on neural networks to approximate functions
– Widely applicable, e.g., collaborative filtering, topic modeling
References
• Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.
– Chapter 10.1 (Variational Inference)

• Blei, D. Variational Inference.
– https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf

• Blei, D. M., Ng, A. Y., & Jordan, M. I. (2003). Latent Dirichlet Allocation. Journal of Machine Learning Research, 3(Jan), 993-1022.
– https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf

• Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. International Conference on Learning Representations (ICLR).
– https://arxiv.org/pdf/1312.6114.pdf
