08 Variational Inference
Hady W. Lauw
Kullback-Leibler Divergence (KL-Divergence)

$$D_{KL}(p \,\|\, q) = -\int p(x) \log \frac{q(x)}{p(x)} \, dx$$

– Always non-negative, i.e., D_KL(p||q) ≥ 0
– Minimized when p(x) and q(x) are identical distributions

https://en.wikipedia.org/wiki/File:KL-Gauss-Example.png
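As a quick numerical illustration of the definition above, the following Python snippet (the two discrete distributions are arbitrary examples, not from the lecture) evaluates the KL-divergence and shows the two properties listed:

import numpy as np

p = np.array([0.6, 0.3, 0.1])
q = np.array([0.4, 0.4, 0.2])

def kl(p, q):
    # D_KL(p || q) = -sum_x p(x) * log(q(x) / p(x))
    return -np.sum(p * np.log(q / p))

print(kl(p, q))   # positive, since p and q differ
print(kl(p, p))   # 0.0, since the distributions are identical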
Latent Variable Model
• Let 𝑞(𝒛) be a distribution over the latent variables. For any choice of
distribution 𝑞(𝒛), the following decomposition holds:
$$\ln p(\boldsymbol{x}\,|\,\theta) = \mathcal{L}(q, \theta) + D_{KL}(q \,\|\, p)$$

$$\mathcal{L}(q, \theta) = \sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z}\,|\,\theta)}{q(\boldsymbol{z})}$$

$$D_{KL}(q \,\|\, p) = -\sum_{\boldsymbol{z}} q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{z}\,|\,\boldsymbol{x}, \theta)}{q(\boldsymbol{z})}$$
Decomposition
• We observe that:
$$p(\boldsymbol{x}, \boldsymbol{z}\,|\,\theta) = p(\boldsymbol{z}\,|\,\boldsymbol{x}, \theta)\, p(\boldsymbol{x}\,|\,\theta)$$
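To make the decomposition concrete, here is a small Python check (the joint table p(x, z|θ) at a fixed observed x and the choice of q(z) are arbitrary examples) that ln p(x|θ) = ℒ(q, θ) + D_KL(q||p):

import numpy as np

p_xz = np.array([0.10, 0.05, 0.25, 0.20])   # p(x, z | theta) at the observed x, for 4 latent states z
p_x = p_xz.sum()                            # p(x | theta) = sum_z p(x, z | theta)
p_post = p_xz / p_x                         # posterior p(z | x, theta)

q = np.array([0.4, 0.1, 0.3, 0.2])          # any distribution q(z) over the latent states

L = np.sum(q * np.log(p_xz / q))            # L(q, theta) = sum_z q(z) ln[ p(x, z | theta) / q(z) ]
kl = -np.sum(q * np.log(p_post / q))        # D_KL(q || p(z | x, theta))

print(np.log(p_x), L + kl)                  # the two values agree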
Revisiting EM: E-Step
Revisiting EM: M-Step
Variational Inference
• For EM, we assume that the evaluation of the posterior 𝑝 𝒛|𝒙, 𝜃 is tractable
• For many models, this evaluation may not be tractable
– E.g., when we introduce a prior that requires complex normalization, or when the dimensionality is too high
• Approximate inference by using a different distribution 𝑞 ≠ 𝑝 that is simpler
and more tractable
– The previous parameters 𝜃 are now absorbed into 𝒛
Jensen’s Inequality
• For a convex function (any minimum is a global minimum): f(E[x]) ≤ E[f(x)]
• Since ln is concave, the inequality reverses, which gives a lower bound on the log-likelihood:

$$\ln p(\boldsymbol{x}) = \ln \int p(\boldsymbol{x}, \boldsymbol{z})\, d\boldsymbol{z} = \ln \int q(\boldsymbol{z})\, \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\, d\boldsymbol{z} = \ln \mathrm{E}_q\!\left[\frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right] \geq \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right] = \mathcal{L}(q)$$
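A tiny numeric check of that step (with arbitrary example values): because ln is concave, ln E[x] ≥ E[ln x].

import numpy as np

x = np.array([0.5, 1.0, 2.0, 4.0])
w = np.array([0.1, 0.2, 0.3, 0.4])   # probability weights

print(np.log(np.sum(w * x)))         # ln E[x]
print(np.sum(w * np.log(x)))         # E[ln x], which is smaller since ln is concave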
Evidence Lower Bound (ELBO)
• To maximize the likelihood, we maximize the ELBO
• Interpretation of ℒ(q):

$$\mathcal{L}(q) = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\right] = \mathrm{E}_q[\ln p(\boldsymbol{x}, \boldsymbol{z})] - \mathrm{E}_q[\ln q(\boldsymbol{z})]$$

– E_q[ln p(x, z)] is the expectation (under q) of the log of the joint probability
– H(q) = −E_q[ln q(z)] is the entropy of the variational distribution
Mean Field Variational Inference
• With factorized variational distribution q(z) = ∏_i q_i(z_i):

$$\mathcal{L}(q) = \int q(\boldsymbol{z}) \ln \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q(\boldsymbol{z})}\, d\boldsymbol{z}$$
General Steps
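As a hedged illustration of these general steps, the sketch below applies the standard coordinate-ascent mean-field recipe, q_j(z_j) ∝ exp(E_{q_-j}[ln p(x, z)]), to a toy model with two binary latent variables (the joint table, initialization, and iteration count are illustrative choices, not taken from the lecture):

import numpy as np

log_p = np.log(np.array([[0.30, 0.10],      # log p(x, z1, z2) at the observed x;
                         [0.05, 0.15]]))    # rows index z1, columns index z2

q1 = np.array([0.5, 0.5])                   # q1(z1), initialized uniform
q2 = np.array([0.5, 0.5])                   # q2(z2), initialized uniform

for _ in range(50):
    # q1(z1) proportional to exp( sum_{z2} q2(z2) * log p(x, z1, z2) )
    q1 = np.exp(log_p @ q2)
    q1 /= q1.sum()
    # q2(z2) proportional to exp( sum_{z1} q1(z1) * log p(x, z1, z2) )
    q2 = np.exp(log_p.T @ q1)
    q2 /= q2.sum()

elbo = q1 @ log_p @ q2 - q1 @ np.log(q1) - q2 @ np.log(q2)
print(q1, q2, elbo)                         # converged factorized q and its ELBO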
VARIATIONAL INFERENCE ON LDA
Latent Dirichlet Allocation
• LDA’s generative process for each document d_i:
– pick a topic distribution θ_i from a Dirichlet prior: θ_i ∼ Dir(α)
– for each of the N_i words in d_i:
• pick a latent class z_n with probability p(z_n | θ_i)
• pick a word w_n with probability p(w_n | z_n, β)
• Document probability:

$$p(\boldsymbol{w}_i \,|\, \alpha, \beta) = \int p(\theta_i | \alpha) \prod_{n=1}^{N_i} \sum_{z_n} p(w_n \,|\, z_n, \beta)\, p(z_n \,|\, \theta_i)\, d\theta_i$$

– w_i are the observed words in document d_i ∈ D
– N_i is the number of words in d_i
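The generative process above can be sketched directly in numpy (topic count, vocabulary size, document length, and hyperparameter values below are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(0)
K, V, N_d = 3, 10, 20                       # topics, vocabulary size, words per document
alpha = np.full(K, 0.5)                     # Dirichlet prior over topic proportions
beta = rng.dirichlet(np.full(V, 0.1), K)    # beta[k, v] = p(word v | topic k)

def generate_document():
    theta = rng.dirichlet(alpha)            # theta_i ~ Dir(alpha)
    words = []
    for _ in range(N_d):
        z = rng.choice(K, p=theta)          # latent class z_n ~ p(z | theta_i)
        w = rng.choice(V, p=beta[z])        # word w_n ~ p(w | z_n, beta)
        words.append(w)
    return words

print(generate_document())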
Variational Inference
• Posterior distribution of the hidden variables
$$p(\theta, \boldsymbol{z} \,|\, \boldsymbol{w}, \alpha, \beta) = \frac{p(\theta, \boldsymbol{z}, \boldsymbol{w} \,|\, \alpha, \beta)}{p(\boldsymbol{w} \,|\, \alpha, \beta)}$$
– drop index 𝑖 for a specific document
– intractable to compute because of coupling of 𝜃 and 𝛽
• Variational inference
– Simplified graphical model with fewer dependencies
$$q(\theta, \boldsymbol{z} \,|\, \gamma, \boldsymbol{\phi}) = q(\theta | \gamma) \prod_{n=1}^{N} q(z_n | \phi_n)$$

– q(θ|γ) for every document is Dirichlet over K topics
– q(z_n|φ_n) for every token is Multinomial over K topics
• Optimization objective

$$(\gamma^*, \boldsymbol{\phi}^*) = \operatorname*{argmin}_{(\gamma, \boldsymbol{\phi})} D_{KL}\big(q(\theta, \boldsymbol{z} \,|\, \gamma, \boldsymbol{\phi}) \,\|\, p(\theta, \boldsymbol{z} \,|\, \boldsymbol{w}, \alpha, \beta)\big)$$
Optimization
• KL minimization is equivalent to ELBO maximization
$$\text{ELBO} = \mathrm{E}_q[\ln p(\theta, \boldsymbol{z}, \boldsymbol{w} \,|\, \alpha, \beta)] + H(q)$$

• The expectation of the log of the joint probability:

$$\mathrm{E}_q[\ln p(\theta, \boldsymbol{z}, \boldsymbol{w} \,|\, \alpha, \beta)] = \mathrm{E}_q[\ln p(\theta | \alpha)] + \mathrm{E}_q[\ln p(\boldsymbol{z} | \theta)] + \mathrm{E}_q[\ln p(\boldsymbol{w} | \boldsymbol{z}, \beta)]$$

$$= \log\Gamma\Big(\sum_{k=1}^{K}\alpha_k\Big) - \sum_{k=1}^{K}\log\Gamma(\alpha_k) + \sum_{k=1}^{K}(\alpha_k - 1)\Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K}\gamma_{k'}\Big)\Big)$$
$$+ \sum_{n=1}^{N}\sum_{k=1}^{K}\phi_{nk}\Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K}\gamma_{k'}\Big)\Big) + \sum_{n=1}^{N}\sum_{k=1}^{K}\sum_{v=1}^{V}\phi_{nk}\, w_{nv}\, \log\beta_{kv}$$

• The entropy of the variational distribution:

$$H(q) = H(\gamma) + \sum_{n=1}^{N}\sum_{k=1}^{K} H(\phi_{nk})$$
$$= -\log\Gamma\Big(\sum_{k=1}^{K}\gamma_k\Big) + \sum_{k=1}^{K}\log\Gamma(\gamma_k) - \sum_{k=1}^{K}(\gamma_k - 1)\Big(\Psi(\gamma_k) - \Psi\Big(\sum_{k'=1}^{K}\gamma_{k'}\Big)\Big) - \sum_{n=1}^{N}\sum_{k=1}^{K}\phi_{nk}\log\phi_{nk}$$

Legend:
• K: number of topics
• N: number of tokens in the document
• V: number of unique words in the vocabulary
• α_k: Dirichlet hyperparameter for topic k
• γ_k: variational Dirichlet parameter for topic k in this document
• φ_nk: variational parameter for assignment of token n to topic k
• w_nv: indicator that token n has word form v
• β_kv: probability of topic k generating word form v
• Γ: Gamma function
• Ψ: Digamma function

For the full derivation, refer to https://youtu.be/2pEkWk-LHmU
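For concreteness, the two ELBO pieces above translate line by line into numpy as the sketch below (argument shapes are assumed conventions: phi is (N, K), w_onehot is (N, V), log_beta is (K, V)):

import numpy as np
from scipy.special import gammaln, digamma

def lda_document_elbo(alpha, gamma, phi, w_onehot, log_beta):
    dig = digamma(gamma) - digamma(gamma.sum())       # Psi(gamma_k) - Psi(sum_k' gamma_k')
    # E_q[ln p(theta|alpha)] + E_q[ln p(z|theta)] + E_q[ln p(w|z, beta)]
    e_log_joint = (gammaln(alpha.sum()) - gammaln(alpha).sum()
                   + np.sum((alpha - 1) * dig)
                   + np.sum(phi * dig)
                   + np.sum(phi * (w_onehot @ log_beta.T)))
    # H(q) = H(gamma) + sum_n sum_k H(phi_nk)
    entropy = (-gammaln(gamma.sum()) + gammaln(gamma).sum()
               - np.sum((gamma - 1) * dig)
               - np.sum(phi * np.log(phi)))
    return e_log_joint + entropy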
Variational Inference Algorithm for LDA
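The per-document coordinate-ascent updates take the form given by Blei, Ng & Jordan (2003): φ_nk ∝ β_{k,w_n} exp(Ψ(γ_k) − Ψ(Σ_k' γ_k')) and γ_k = α_k + Σ_n φ_nk. The sketch below is a hedged Python rendering (initialization and the convergence threshold are assumed choices):

import numpy as np
from scipy.special import digamma

def lda_variational_inference(words, alpha, beta, n_iters=50, tol=1e-4):
    """words: list of word ids; alpha: (K,) Dirichlet prior; beta: (K, V) topic-word probabilities."""
    K, N = len(alpha), len(words)
    phi = np.full((N, K), 1.0 / K)          # q(z_n | phi_n), initialized uniform over topics
    gamma = alpha + N / K                   # q(theta | gamma), initialized as in Blei et al.
    for _ in range(n_iters):
        gamma_old = gamma.copy()
        exp_elog_theta = np.exp(digamma(gamma) - digamma(gamma.sum()))
        for n, w in enumerate(words):       # phi_nk proportional to beta[k, w_n] * exp(E_q[ln theta_k])
            phi[n] = beta[:, w] * exp_elog_theta
            phi[n] /= phi[n].sum()
        gamma = alpha + phi.sum(axis=0)     # gamma_k = alpha_k + sum_n phi_nk
        if np.abs(gamma - gamma_old).sum() < tol:
            break
    return gamma, phi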
Variational Autoencoder (VAE)
• Designed not only for content encoding, but also for content generation
• Key idea:
– regularize the latent space of the encoding so that nearby points in the latent space will
generate similar decodings
https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
Generative Model
• Let us consider a dataset {x_i} for i = 1, …, N
– Data is generated by a random process involving latent variables 𝒛
• Generative process:
– For each data point 𝑖 = 1, … , 𝑁:
• Sample latent variable z_i from a parameterized distribution p(z|θ)
• Sample observation x_i from a parameterized distribution p(x|z, θ)
Variational Inference
• Interpretation:
– 𝒛 is the encoding of input 𝒙
– 𝑞(𝒛|𝒙, 𝜙) is the encoder
– 𝑝(𝒙|𝒛, 𝜃) is the decoder
Decoder and Encoder
• Decoder: let the prior p(z|θ) be a factorized multivariate Gaussian:
ln p(z|θ) = ln 𝒩(0, I)
• Encoder: let q(z|x, φ) be a factorized multivariate Gaussian:
ln q(z|x, φ) = ln 𝒩(ρ, ω²I)

https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf
https://www.cl.cam.ac.uk/~pv273/slides/UCLSlides.pdf
Optimization (for Gaussian)
• Objective is to maximize ELBO:
$$\mathcal{L}(q) = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}, \boldsymbol{z}|\theta)}{q(\boldsymbol{z}|\boldsymbol{x}, \phi)}\right] = \mathrm{E}_q\!\left[\ln \frac{p(\boldsymbol{x}|\boldsymbol{z}, \theta)\, p(\boldsymbol{z}|\theta)}{q(\boldsymbol{z}|\boldsymbol{x}, \phi)}\right] = \mathrm{E}_q[\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta)] - D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,\|\, p(\boldsymbol{z}|\theta)\big)$$
• The first term E_q[ln p(x|z, θ)] is the expected (negative) reconstruction error
– for each point, estimated by taking L samples from q(z|x, φ) via Monte Carlo expectation estimation

$$\mathrm{E}_q[\ln p(\boldsymbol{x}|\boldsymbol{z}, \theta)] \approx \frac{1}{L}\sum_{l=1}^{L} \ln p(\boldsymbol{x}|\boldsymbol{z}^{(l)})$$
– requires “reparameterization trick” to make backpropagation and gradient descent possible
• The second term −D_KL(q(z|x, φ) || p(z|θ)) regularizes q(z|x, φ) towards the prior 𝒩(0, I)
– Can be derived analytically (no sampling needed)

$$-D_{KL}\big(q(\boldsymbol{z}|\boldsymbol{x}, \phi) \,\|\, p(\boldsymbol{z}|\theta)\big) = \frac{1}{2}\sum_{m=1}^{M}\Big(1 + \ln(\omega_m^2) - \rho_m^2 - \omega_m^2\Big)$$

– Here M is the dimensionality of z (the number of latent variables)
– ω_m and ρ_m are neural network functions of x and φ
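Putting the pieces together, here is a minimal VAE sketch in PyTorch, assuming fully-connected layers, a Bernoulli decoder over inputs in [0, 1], and L = 1 Monte Carlo sample (these are illustrative assumptions, not the lecture's exact architecture); it shows the encoder outputs ρ(x) and ln ω²(x), the reparameterization trick, and the analytic Gaussian KL term above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, x_dim=784, h_dim=256, z_dim=20):
        super().__init__()
        self.enc = nn.Linear(x_dim, h_dim)
        self.enc_rho = nn.Linear(h_dim, z_dim)      # rho(x): mean of q(z|x, phi)
        self.enc_logvar = nn.Linear(h_dim, z_dim)   # ln omega^2(x): log-variance of q(z|x, phi)
        self.dec = nn.Linear(z_dim, h_dim)
        self.dec_out = nn.Linear(h_dim, x_dim)

    def encode(self, x):
        h = torch.relu(self.enc(x))
        return self.enc_rho(h), self.enc_logvar(h)

    def reparameterize(self, rho, logvar):
        # z = rho + omega * eps with eps ~ N(0, I): sampling stays differentiable w.r.t. phi
        eps = torch.randn_like(rho)
        return rho + torch.exp(0.5 * logvar) * eps

    def decode(self, z):
        return torch.sigmoid(self.dec_out(torch.relu(self.dec(z))))

    def forward(self, x):
        rho, logvar = self.encode(x)
        z = self.reparameterize(rho, logvar)
        return self.decode(z), rho, logvar

def negative_elbo(x, x_recon, rho, logvar):
    # Reconstruction term: -E_q[ln p(x|z, theta)], Bernoulli decoder, single sample (L = 1)
    recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Analytic KL term: D_KL(q(z|x, phi) || N(0, I)) = -1/2 * sum(1 + ln(omega^2) - rho^2 - omega^2)
    kl = -0.5 * torch.sum(1 + logvar - rho.pow(2) - logvar.exp())
    return recon + kl

Minimizing negative_elbo over mini-batches with any stochastic gradient optimizer trains the encoder (φ) and the decoder (θ) jointly.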
Visualization of Learned Data Manifold
Different Dimensionalities of Latent Space
Conclusion
• Variational Inference:
– Approximating inference of a latent variable model with simpler, more tractable inference
– Maximizes evidence lower bound (ELBO) or minimizes KL-divergence between
variational distribution and the posterior distribution
– Mean field inference is one simplification where the variational distribution factorizes
• Latent Dirichlet Allocation:
– Variational inference is an alternative to Gibbs sampling
– Uses a variational distribution that decouples some dependent variables
– Variational distribution factorizes, thus the learning algorithm is parallelizable
• Variational Auto-Encoder:
– An auto-encoder that is also a generative model
– Encoding is not deterministic, but instead is a distribution
– Lends itself naturally to variational inference
– Relies on neural network to approximate functions
– Widely applicable, e.g., collaborative filtering, topic modeling
References
• Bishop, C. M. (2006). Pattern recognition and machine learning. Springer.
– Chapter 10.1 (Variational Inference)
• Blei, D. M., Ng, A. Y., & Jordan, M. I. (2003). Latent Dirichlet allocation.
Journal of Machine Learning Research, 3(Jan), 993-1022.
– https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf