Partitioned Variational Inference: A Unified Framework Encompassing Federated and Continual Learning
Abstract
Variational inference (VI) has become the method of choice for fitting many modern probabilistic
models. However, practitioners are faced with a fragmented literature that offers a bewildering
array of algorithmic options. First, the variational family. Second, the granularity of the updates,
e.g. whether the updates are global or local to each data point and employ message passing. Third,
the method of optimization (bespoke or black-box, closed-form or stochastic updates, etc.). This
paper presents a new framework, termed Partitioned Variational Inference (PVI), that explicitly
acknowledges these algorithmic dimensions of VI, unifies disparate literature, and provides guidance
on usage. Crucially, the proposed PVI framework allows us to identify new ways of performing
VI that are ideally suited to challenging learning scenarios including federated learning (where
distributed computing is leveraged to process non-centralized data) and continual learning (where
new data and tasks arrive over time and must be accommodated quickly). We showcase these new
capabilities by developing communication-efficient federated training of Bayesian neural networks
and continual learning for Gaussian process models with private pseudo-points. The new methods
significantly outperform the state-of-the-art, whilst being almost as straightforward to implement as
standard VI.
1 Introduction
Variational methods recast approximate inference as an optimization problem, thereby enabling advances
in optimization to be leveraged for inference. VI has enabled approaches including natural gradient
methods, mirror-descent, trust region and stochastic (mini-batch) optimization to be tapped in this way.
The approach has been successful, with VI methods often lying on the efficient frontier of approximate
inference’s speed-accuracy trade-off. VI has consequently become one of the most popular varieties of
approximate inference. For example, it is now a standard approach for Gaussian process models [Titsias,
2009], latent topic models [Blei et al., 2003], and deep generative models [Kingma and Welling, 2014].
Deployment of VI requires the practitioner to make three fundamental choices. First, the form of the
approximate family which ranges from simple mean-field or factorized distributions, through unfactorized
exponential families to complex non-exponential family distributions. Second, the granularity of
variational inference which includes, on the one hand, approaches based on the global variational
free-energy, and on the other those that consider a single data point at a time and employ local
message passing. Third, the form of the variational updates which encompasses the optimization method
employed for maximizing the global variational free-energy or the form of the message passing updates.
A large body of work has investigated how the choice of approximating family affects the accuracy
of VI [MacKay, 2003, Wang and Titterington, 2004, Turner and Sahani, 2011] and how additional
approximations can enable VI to support more complex approximate families [Jaakkola and Jordan,
1998, Rezende and Mohamed, 2015, Salimans et al., 2015, Ranganath et al., 2016, Mescheder et al., 2017].
This is a fundamental question, but it is orthogonal to the focus of the current paper. Instead, we focus
on the second two choices. The granularity of variational inference is an important algorithmic dimension.
Whilst global variational inference has more theoretical guarantees and is arguably simpler to implement,
local variational inference offers unique opportunities for online or continual learning (e.g. allowing ‘old’
data to be sporadically revisited) and distributed computing (e.g. supporting asynchronous lock-free
updates). The form of the updates is equally important, with a burgeoning set of alternatives. For global
VI these include gradient ascent, natural gradient and mirror descent, approximate second-order
methods, stochastic versions thereof, collapsed VI and fixed-point updates, to name but a few. For local
VI, there has been less exploration of the options, but damping in natural and moment space is often
employed.
The goal of this paper is to develop a unifying framework, termed Partitioned Variational Infer-
ence (PVI), that explicitly acknowledges that the granularity and the optimization method are two
fundamental algorithmic dimensions of VI. The new framework 1. generalizes and extends current
theoretical results in this area, 2. reveals the relationship between a large number of existing schemes,
and 3. identifies opportunities for innovation, a selection of which are demonstrated in experiments.
We briefly summarize the contributions of this paper, focusing on the unified viewpoint and novel
algorithmic extensions to support federated and continual learning.
1.1 Unification
The main theoretical contributions of the paper, described in sections 2 to 4, are: to develop Partitioned
Variational Inference; clean up, generalize and derive new supporting theory (including PVI fixed-point
optimization, mini-batch approximation, hyperparameter learning); and show that PVI subsumes
standard global variational inference, (local) variational message passing, and other well-established
approaches. In addition, we also show in section 4 that damped fixed-point optimization and natural
gradient methods applied to PVI are equivalent to variationally-limited power EP.
In section 4 PVI is used to connect a large literature that has become fragmented with separated
strands of related, but mutually uncited work. More specifically we unify work on: online VI [Ghahramani
and Attias, 2000, Sato, 2001, Broderick et al., 2013, Bui et al., 2017b, Nguyen et al., 2018]; global
VI [Sato, 2001, Hensman et al., 2012, Hoffman et al., 2013, Salimans and Knowles, 2013, Sheth and
Khardon, 2016a, Sheth et al., 2015, Sheth and Khardon, 2016b]; local VI [Knowles and Minka, 2011,
Wand, 2014, Khan and Lin, 2018]; power EP and related algorithms [Minka, 2001, 2004, Li et al., 2015,
Hasenclever et al., 2017, Gelman et al., 2014]; and stochastic mini-batch variants of these algorithms
[Hoffman et al., 2013, Li et al., 2015, Khan and Lin, 2018]. Figures 2 and 3 and table 1 present a
summary of these relationships in the context of PVI.
• modern data sets can often be distributed inhomogeneously and unevenly across many machines;
for example, mobile devices can contain many images which can be used for training a classification
model, but accessing such information is often restricted and privacy-sensitive;
• the inference or prediction step is often needed in an any-time fashion at each machine, i.e. each
machine needs to have access to a high-quality model to make predictions without having to send
data to a remote server.
These requirements are often not satisfied in the traditional training pipelines, many of which require
data to be stored in a single machine, or in a data center where it is typically distributed among
many machines in a homogeneous and balanced fashion [see e.g. Dean et al., 2012, Zhang et al., 2015,
Chen et al., 2016]. Federated learning attempts to bridge this gap by tackling the aforementioned
constraints. Additionally, this type of learning is arguably less privacy-sensitive as compared to
centralized learning approaches, as it does not require local data to be collected and sent to a central
server. It can also be further improved by employing encrypted aggregation steps [Bonawitz et al.,
2017] or differentially-private mechanisms [Dwork and Roth, 2014].
Distributed inference is also an active research area in the Bayesian statistics and machine learn-
ing literature. For example, parallel Markov chain Monte Carlo approaches typically run multiple
independent Markov chains on different partitions of the data set, but require heuristics to aggregate,
reweight and average the samples at test time [see e.g. Wang and Dunson, 2013, Scott et al., 2016].
The closest to our work is the distributed EP algorithms of Gelman et al. [2014] and Hasenclever et al.
[2017], which employ (approximate) MCMC for data partitions and EP for communication between
workers. However, it is not clear that these distributed approaches will work well in the federated settings
described above. In section 5, we demonstrate that PVI can naturally and flexibly address the above
challenges, and thus be used for federated learning with efficient synchronous or lock-free asynchronous
communication. The proposed approach can be combined with recent advances in Monte Carlo VI for
neural networks, enabling fast and communication-efficient training of Bayesian neural networks on
non-iid federated data. We provide an extensive experiment comparing to alternative approaches in
section 7.
In section 6, PVI is used to improve continual learning for sparse Gaussian process models, extending the streaming approximation of Bui et al.
[2017a] and allowing principled handling of hyperparameters and private pseudo-points for new data. The
new technique is shown to be superior to alternative online learning approaches on various toy and
real-world data sets in section 7. We also show in section 5 that continual learning can be reframed as
a special case of federated learning.
2 Partitioned Variational Inference

Consider a data set partitioned into M groups, y = {y_m}_{m=1}^M. PVI approximates the posterior by a product of the prior and M approximate likelihood factors,

q(θ) = p(θ) ∏_{m=1}^M t_m(θ) ≈ (1/Z) p(θ) ∏_{m=1}^M p(y_m|θ) = p(θ|y),    (1)
where Z is the normalizing constant of the true posterior, or marginal likelihood. The approximate
likelihood t_m(θ) will be refined by PVI to approximate the effect the likelihood term p(y_m|θ) has on the
posterior. Note that the form of q(θ) in (1) is similar to that employed by the expectation propagation
algorithm [Minka, 2001], but with two differences. First, the approximate posterior is not restricted to
lie in the exponential family, as is typically the case for EP. Second, the approximate posterior does
not include a normalizing constant. Instead, the PVI algorithm will automatically ensure that the
product of the prior and approximate likelihood factors in (1) is a normalized distribution. We will
show that PVI will return an approximation to the marginal likelihood log Z = log p(y) in addition to
the approximation of the posterior.
Algorithm 1 details the PVI algorithm. At each iteration i, we select an approximate likelihood
to refine according to a schedule b_i ∈ {1, . . . , M}. The approximate likelihood t_{b_i}^{(i−1)}(θ) obtained from
the previous iteration will be refined and the corresponding data-group is denoted y_{b_i}. The refinement
proceeds in two steps. First, we refine the approximate posterior using the local (negative) variational
free energy, q^{(i)}(θ) = argmax_{q(θ)∈Q} F^{(i)}(q(θ)), where the optimization is over a tractable family Q and
the local free-energy involves the data-group y_{b_i} through the tilted distribution p̂^{(i)}(θ) ∝ q^{(i−1)}(θ) p(y_{b_i}|θ)/t_{b_i}^{(i−1)}(θ).
Second, we update the approximate likelihood, t_{b_i}^{(i)}(θ) = t_{b_i}^{(i−1)}(θ) q^{(i)}(θ)/q^{(i−1)}(θ). At convergence, the sum of the local
free-energies is equal to the global variational free-energy. The following properties apply for general
q(θ), and are not limited to the exponential family.¹
Property 1 Maximizing the local free-energy F^{(i)}(q(θ)) is equivalent to the KL optimization

q^{(i)}(θ) = argmin_{q(θ)∈Q} KL[ q(θ) ‖ p̂^{(i)}(θ) ].    (3)
The proof is straightforward (see A.1). The tilted distribution can be justified as a sensible target
as it removes the approximate likelihood t_{b_i}^{(i−1)}(θ) from the current approximate posterior and replaces
it with the true likelihood p(y_{b_i}|θ). In this way, the tilted distribution comprises one true likelihood,
M − 1 approximate likelihoods and the prior. The KL optimization then ensures the new posterior
better approximates the true likelihood’s effect, in the context of the approximate likelihoods and the
prior.
Property 2 At the end of each iteration i = 0, 1, . . ., q^{(i)}(θ) = p(θ) ∏_{m=1}^M t_m^{(i)}(θ).
Again the proof is simple (see A.2), but it relies on PVI initializing the approximate likelihood factors
to unity so that q^{(0)}(θ) = p(θ).
These results are more complex to show, but can be derived by computing the derivative and Hessian
of the global free-energy and substituting into these expressions the derivatives and Hessians of the
local free-energies (see A.3). The fact that the fixed point of PVI recovers a global VI solution (both
the optimal q(θ) and the global free-energy at this optimum) is the main theoretical justification for
employing PVI. However, we do not believe that there is a Lyapunov function for PVI, indicating that
it may oscillate or diverge in general.
Having laid out the general framework for PVI, what remains to be decided is the method used for
optimizing the local free-energies. In a moment we consider three choices: analytic updates, off-the-shelf
optimization methods and fixed-point iterations, as well as discussing how stochastic approximations
can be combined with these approaches. Before turning to these choices, we compare and contrast the
algorithmic benefits of the local and global approaches to VI in different settings. This discussion will
help shape the development of the optimization choices which follows.
2.1 When should a local VI approach be employed rather than a global one?
We will describe in section 4 how the PVI framework unifies a large body of existing literature, thereby
providing a useful conceptual scaffold for understanding the relationship between algorithms. However,
¹However, we will only consider exponential family approximations in the experiments in section 7.
Algorithm 1 Partitioned Variational Inference
Input: data partition {y_1, . . . , y_M}, prior p(θ)
Initialize: t_m^{(0)}(θ) := 1 for all m = 1, 2, . . . , M; q^{(0)}(θ) := p(θ).
for i = 1, 2, . . . until convergence do
  Choose a data group to refine according to the schedule: b_i ∈ {1, . . . , M}.
  Update the approximate posterior: q^{(i)}(θ) := argmax_{q(θ)∈Q} F^{(i)}(q(θ)).
  Update the approximate likelihood: t_{b_i}^{(i)}(θ) := t_{b_i}^{(i−1)}(θ) q^{(i)}(θ)/q^{(i−1)}(θ).
end for
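To make the updates concrete, the following is a minimal Python sketch of Algorithm 1 for a toy conjugate model (Gaussian prior, Gaussian likelihood with known noise); the model, data, and variable names are illustrative assumptions, not part of the algorithm's specification. Because the tilted distribution is Gaussian here, the KL optimization of Property 1 is available in closed form and a single pass recovers the exact posterior.

import numpy as np

# Toy conjugate model: y ~ N(theta, sy2) in each group, prior theta ~ N(0, 1).
# q and each approximate likelihood t_m are unnormalized Gaussians stored as
# natural parameters (n1, n2) = (mean/variance, -1/(2*variance)).
rng = np.random.default_rng(0)
sy2 = 0.5
groups = [rng.normal(1.0, np.sqrt(sy2), size=5) for _ in range(4)]  # M = 4

prior = np.array([0.0, -0.5])          # N(0, 1) in natural parameters
t = [np.zeros(2) for _ in groups]      # t_m(theta) := 1 initially
q = prior.copy()                       # q^(0)(theta) = p(theta)

for i in range(len(groups)):           # schedule: one pass, b_i = i
    cavity = q - t[i]                  # q^(i-1) / t_{b_i}^(i-1) in natural params
    y = groups[i]
    # Conjugate case: the tilted distribution is Gaussian, so the KL
    # minimization returns it exactly and the refined factor equals the
    # true likelihood's natural parameters.
    lik = np.array([y.sum() / sy2, -len(y) / (2 * sy2)])
    q = cavity + lik                   # new approximate posterior q^(i)
    t[i] = q - cavity                  # t_{b_i}^(i) = q^(i) / cavity

var = -1.0 / (2.0 * q[1])
mean = q[0] * var
print(mean, var)                       # matches the exact conjugate posterior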
Figure 1: Steps of the PVI algorithm when being used for continual learning [left] and federated learning
[right].
it is important to ask: What algorithmic and computation benefits, if any, arise from considering a set
of local free-energy updates, rather than a single global approximation (possibly leveraging stochastic
mini-batch approximation)?
In a nutshell, we will show that if the data set is fixed before inference is performed (batch learning)
or arrives in a simple online iid way (simple online learning), and distributed computation is not available,
then global VI will typically be simpler to implement, require less memory, and converge faster than
more local versions of PVI (the case of scaling collapsed bounds being a possible exception). However,
if the conditions above are not met, the local versions of PVI will be appropriate. We will now unpack
important examples of this sort.
The PVI approach is ideally suited to the distributed setting, with simple distributed variants
allowing asynchronous distributed updates. One simple approach, similar to that of Hasenclever et al.
[2017], uses M workers that are each allocated a data group y_m. The workers store and refine the
associated approximate likelihood t_m(θ). A server maintains and updates the approximate posterior and
communicates it to the workers. An idle worker receives the current posterior from the server, optimizes
the local free-energy, computes the change in the local approximate likelihood ∆_m(θ) = t_m^{(new)}(θ)/t_m^{(old)}(θ),
sends this to the server, and repeats. The local workers do not change q(θ) directly. Instead, the server
maintains a queue of approximate likelihood updates and applies these to the approximate posterior,
q^{(new)}(θ) = q^{(old)}(θ) ∆_m(θ). This setup supports asynchronous updates of the approximate likelihood
factors. See fig. 1 for a pictorial depiction of these steps. In contrast, global VI is generally ill-suited to
the distributed setting. Although the free-energy optimization can be parallelized over data points,
typically this will only be advantageous for large mini-batches where the extra communication overhead
does not dominate. Large mini-batches often result in slow optimization progress (early in learning it is
often clear how to improve q(θ) after seeing only a small number of data points). The special case of
global VI employing mini-batch approximations and natural gradient updates can support asynchronous
distributed processing if each worker receives statistically identical data and updates with the same
frequency. It could not operate successfully when each node contains different amounts or types of data,
or if some workers update more frequently than others.
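The following is a hypothetical sketch of the server-side logic just described, assuming q(θ) and the deltas ∆_m(θ) are exponential-family factors represented by natural parameters, so that multiplying densities reduces to adding parameter vectors; the class and method names are our own inventions for illustration.

import queue
import threading
import numpy as np

# Hypothetical server for asynchronous PVI: workers send natural-parameter
# differences corresponding to Delta_m = t_m^(new) / t_m^(old); the server
# applies them to q in arrival order.
class PVIServer:
    def __init__(self, prior_nat):
        self.q_nat = prior_nat.copy()   # natural parameters of q(theta)
        self.deltas = queue.Queue()     # queue of pending factor updates
        self.lock = threading.Lock()

    def submit(self, delta_nat):        # called by an idle worker
        self.deltas.put(delta_nat)

    def get_posterior(self):            # workers fetch the current q
        with self.lock:
            return self.q_nat.copy()

    def run_step(self):                 # q^(new) = q^(old) * Delta_m
        delta = self.deltas.get()
        with self.lock:
            self.q_nat += delta

server = PVIServer(prior_nat=np.array([0.0, -0.5]))
server.submit(np.array([2.0, -1.0]))    # a worker's refined factor change
server.run_step()
print(server.get_posterior())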
Distributed versions of PVI not only enable VI to be scaled to large problems, but they also allow
inference algorithms to be sent to user data, rather than requiring user data to be collected and
centralized before performing inference. Consider the situation where workers are personal devices,
like mobile phones, containing user data y_m. Here the local free-energy updates can be performed
client-side on the user's devices and only summaries t_m(θ) of the relevant aspects of that information
are communicated back to the central server. The frequency with which these messages are sent might
be limited to improve security. Such an implementation is arguably more secure than one in which the
user data (or associated gradients) are sent back to a central server [The Royal Society, 2017]. Since the
amount and type of data at the nodes is outside of the control of the algorithm designer, mini-batch
natural gradient global VI will generally be inappropriate for this setting.
The PVI approach is also well suited to the continual or life-long learning setting. These settings
are very general forms of online learning in which new data regularly arrive in a potentially non-iid
way, tasks may change over time, and entirely new tasks may emerge. In this situation, the PVI
framework can not only be used to continuously update the posterior distribution q(θ) in light of new
data by optimizing the local free-energy for the newly seen data, it can also be used to revisit old
data groups (potentially in a judiciously selected way) thereby mitigating problems like catastrophic
forgetting. The update steps for this learning scenario are illustrated in fig. 1. In contrast, global
VI is fundamentally ill-suited to the general online setting. The special case of global VI employing
mini-batch approximations with natural gradient updates may be appropriate when the data are iid
and only one update is performed for each new task (simple online learning), but it is not generally
applicable.
We will return to discuss the key issues raised in this section – the speed of convergence, memory
overhead, online learning, and distributed inference – in the context of different options for carrying out
the optimization of the local free-energies in section 3.
2.2 Hyperparameter Learning
Many probabilistic models depend on a set of hyperparameters ϵ and it is often necessary to learn
suitable settings from data to achieve good performance on a task. One method is to optimize the
variational free-energy, thereby approximating maximum likelihood learning. The gradient of the global
variational free-energy decomposes into a set of local computations, as shown in appendix B,
d/dϵ F(ϵ, q(θ)) = Σ_{m=1}^M E_{q(θ)}[ d/dϵ log p(y_m|θ, ϵ) ] + E_{q(θ)}[ d/dϵ log p(θ|ϵ) ].    (5)
This expression holds for general q(θ) and is valid both for coordinate ascent (updating ϵ with q(θ)
fixed) and for optimizing the collapsed bound (where the approximate posterior optimizes the global
free-energy, q(θ) = q*(θ), and therefore depends implicitly on ϵ). Notice that this expression is amenable
to stochastic approximation, which leads to optimization schemes that use only local information at
each step. When combined with different choices for the optimization of the local free-energies wrt q(θ),
this leads to a wealth of possible hyperparameter optimization schemes.
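As a concrete illustration, the sketch below forms a simple Monte Carlo estimate of the gradient in eq. (5) for a toy model in which only the prior p(θ|ϵ) = N(0, exp(ϵ)) depends on the hyperparameter; the model and values are assumptions made for the example.

import numpy as np

# MC estimate of eq. (5) for a toy model: prior p(theta | eps) = N(0, exp(eps)),
# likelihood independent of eps, q(theta) = N(m, v); only the prior term
# contributes to the gradient here.
rng = np.random.default_rng(1)
m, v, eps = 0.8, 0.2, 0.0

theta = rng.normal(m, np.sqrt(v), size=100_000)   # theta ~ q(theta)
# d/d eps log N(theta; 0, exp(eps)) = -1/2 + theta^2 exp(-eps) / 2
grad_mc = np.mean(-0.5 + 0.5 * theta**2 * np.exp(-eps))
grad_exact = -0.5 + 0.5 * (v + m**2) * np.exp(-eps)
print(grad_mc, grad_exact)                        # the two should agree closely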
In cases where a distributional estimate for the hyperparameters is necessary, e.g. in continual
learning, the PVI framework above can be extended to handle the hyperparameters. In particular, the
approximate posterior in eq. (1) can be modified as follows,

q(θ, ϵ) = p(ϵ) p(θ|ϵ) ∏_{m=1}^M t_m(θ, ϵ) ≈ (1/Z) p(ϵ) p(θ|ϵ) ∏_{m=1}^M p(y_m|θ, ϵ) = p(θ, ϵ|y),    (6)

where the approximate likelihood factor t_m(θ, ϵ) now involves both the model parameters and the
hyperparameters. Similar to eq. (2), the approximate posterior above leads to a corresponding local
variational free-energy.
Note that this approach retains all favourable properties of PVI such as local computation and flexibility
in choosing optimization strategies and stochastic approximations.
3.2 Off-the-shelf Optimizers for Local Free-energy Optimization
If analytic updates are not tractable, the local free-energy optimizations can be carried out using
standard optimizers. The PVI framework automatically breaks the data set into a series of local
free-energy optimization problems and the propagation of uncertainty between the data groups weights
the information extracted from each. This means non-stochastic optimizers such as BFGS can now be
leveraged in the large data setting. Of course, if a further stochastic approximation like Monte Carlo VI
is employed for each local optimization, stochastic optimizers such as RMSProp [Tieleman and Hinton,
2012] or Adam [Kingma and Ba, 2014] might be more appropriate choices. In all cases, since the local free-energy
is equivalent in form to a global free-energy with an effective prior p_eff(θ) = q^{(i−1)}(θ)/t_{b_i}^{(i−1)}(θ),
PVI can be implemented via trivial modification to existing code for global VI. This is a key advantage
of PVI over previous local VI approaches, such as variational message passing [Winn et al., 2005, Winn
and Minka, 2009, Knowles and Minka, 2011], in which bespoke and closed-form updates are needed for
different likelihoods and cavity distributions.
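For exponential-family approximations, forming the effective prior is especially simple: dividing q^{(i−1)} by the factor t_{b_i}^{(i−1)} is a subtraction of natural parameters. Below is a minimal sketch for diagonal Gaussians; the helper names are hypothetical.

import numpy as np

# Effective prior p_eff = q^(i-1) / t_b^(i-1) for diagonal Gaussians: division
# of densities is subtraction of natural parameters, after which existing
# global-VI code can simply be run with p_eff in place of the prior.
def to_natural(mean, var):
    return mean / var, -0.5 / var

def from_natural(n1, n2):
    var = -0.5 / n2
    return n1 * var, var

def effective_prior(q_mean, q_var, t_n1, t_n2):
    qn1, qn2 = to_natural(q_mean, q_var)
    return from_natural(qn1 - t_n1, qn2 - t_n2)

# e.g. q = N(1.0, 0.1) and a factor with natural parameters (2.0, -1.0):
print(effective_prior(np.array([1.0]), np.array([0.1]),
                      np.array([2.0]), np.array([-1.0])))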
3.3 Local Free-energy Fixed Point Updates, Natural Gradient Methods, and Mirror Descent
An alternative to using off-the-shelf optimizers is to derive fixed-point update equations by zeroing the
gradients of the local free-energy. These fixed-point updates have elegant properties for approximate
posterior distributions that are in the exponential family.
Property 4 If the prior and approximate likelihood factors are in the un-normalized exponential family,
t_m(θ) = t_m(θ; η_m) = exp(η_m^⊤ T(θ)), so that the variational distribution is in the normalized exponential
family, q(θ) = exp(η_q^⊤ T(θ) − A(η_q)), then the stationary point of the local free-energy, dF^{(i)}(q(θ))/dη_q = 0,
implies

η_{b_i}^{(i)} = C^{−1} d/dη_q E_q[ log p(y_{b_i}|θ) ],    (8)

where C := d²A(η_q)/(dη_q dη_q^⊤) = cov_{q(θ)}[T(θ), T^⊤(θ)] is the Fisher information. Moreover, the Fisher information
can be written as C = dμ_q/dη_q, where μ_q = E_q[T(θ)] is the mean parameter of q(θ). Hence,

η_{b_i}^{(i)} = d/dμ_q E_q[ log p(y_{b_i}|θ) ].    (9)
For some approximate posterior distributions q(θ), taking derivatives of the average log-likelihood with
respect to the mean parameters is analytic (e.g. Gaussian) and for some it is not (e.g. gamma).
These conditions, derived in appendix A.4, can be used as fixed point equations. That is, they can be
iterated, possibly with damping ρ,

η_{b_i}^{(i)} = (1 − ρ) η_{b_i}^{(i−1)} + ρ d/dμ_q E_q[ log p(y_{b_i}|θ) ].    (10)
These iterations, which form an inner-loop in PVI, are themselves not guaranteed to converge (there
is no Lyapunov function in general, so, for example, the local free-energy need not decrease at every
step).
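A small numerical sketch of the damped iteration in eq. (10) follows, assuming a univariate Gaussian q(θ) and a conjugate Gaussian likelihood so that the gradient with respect to the mean parameters (E[θ], E[θ²]) is analytic, and in fact constant, making the iteration converge geometrically; in non-conjugate models this gradient depends on q and must be recomputed at every step.

import numpy as np

# Damped fixed-point updates (eq. 10) for a univariate Gaussian q(theta) and
# a conjugate Gaussian likelihood with noise variance sy2.
y, sy2, rho = np.array([0.9, 1.3, 1.1]), 0.5, 0.3

cavity = np.array([0.0, -0.5])     # prior N(0, 1); all other factors are unity
eta_b = np.zeros(2)                # natural parameters of t_b(theta)
# d E_q[log p(y_b | theta)] / d mu = (sum(y)/sy2, -len(y)/(2*sy2)) here;
# for non-conjugate likelihoods this would be recomputed inside the loop.
grad_mu = np.array([y.sum() / sy2, -len(y) / (2 * sy2)])

for _ in range(30):
    eta_b = (1 - rho) * eta_b + rho * grad_mu
q_nat = cavity + eta_b
print(q_nat, grad_mu)              # eta_b has converged to grad_mu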
The fixed point updates are the natural gradients of the local free-energy and the damped versions
are natural gradient ascent [Sato, 2001, Hoffman et al., 2013]. The natural gradients could also be used in
other optimization schemes [Hensman et al., 2012, Salimbeni et al., 2018]. The damped updates are also
equivalent to performing mirror-descent [Raskutti and Mukherjee, 2015, Khan and Lin, 2018], a general
form of proximal algorithm [Parikh and Boyd, 2014] that can be interpreted as a trust-region method.
For more details about the relationship between these methods, see appendix A.7. Additionally, while
natural gradients or fixed-point updates have been shown to be effective in the batch global VI setting
[see e.g. Honkela et al., 2010], we present some results in appendix E.6 showing that adaptive first-order
methods employing flat gradients, such as Adam [Kingma and Ba, 2014], perform as well as natural
gradient methods when stochastic mini-batch approximations are used.
For these types of updates there is an interesting relationship between PVI and global (batch) VI:
Property 5 PVI methods employing parallel updates result in identical dynamics for q(θ), given by the
following equation, regardless of the partition of the data employed:

η_q^{(i)} = η_0 + d/dμ_{q^{(i−1)}} E_{q^{(i−1)}}[ log p(y|θ) ] = η_0 + Σ_{n=1}^N d/dμ_{q^{(i−1)}} E_{q^{(i−1)}}[ log p(y_n|θ) ].    (11)
See A.5 for the proof. If parallel fixed-point updates are desired, then it is more memory efficient to
employ batch VI (M = 1), since then only one global set of natural parameters needs to be retained.
However, as previously discussed, using M = 1 gives up opportunities for online learning and distributed
computation (e.g. asynchronous updates).
3.4.2 Stochastic Scheduling of Updates Between Local Free-Energies
The second form of stochastic approximation is to randomize the update schedule, for example, using
M = N and randomly selecting subsets of data to update in parallel. This can be memory intensive,
requiring N local natural parameters to be stored. A more memory efficient approach is to fix the
mini-batches across epochs and to visit the data groups y_m in a random order [Khan and Lin, 2018].
For the simplified fixed point updates, this yields

η_m^{(i)} = (1 − ρ) η_m^{(i−1)} + ρ d/dμ_{q^{(i−1)}} E_{q^{(i−1)}}[ log p(y_m|θ) ].    (14)
This approach results in a subtly different update to q, one that retains a specific approximation to the
likelihood of each data partition rather than a single global approximation,

η_q^{(i)} = η_q^{(i−1)} + ρ ( d/dμ_{q^{(i−1)}} E_{q^{(i−1)}}[ log p(y_m|θ) ] − η_m^{(i−1)} ).    (15)

If the first approach in eq. (14) employs learning rates that obey the Robbins-Monro conditions, the
fixed points will be identical to the second approach in eq. (15) and they will correspond to optima of
the global free-energy.
Figure 2: Variational inference schemes encompassed by the PVI framework.
Hensman et al. [2012] and Hoffman et al. [2013] applied the insight to conjugate models when optimizing collapsed variational
free-energies and deriving stochastic natural gradient descent, respectively. Salimans and Knowles
[2013] apply the fixed points to non-conjugate models where the expectations over q are intractable
and use Monte Carlo to approximate them, but they explicitly calculate the Fisher information matrix,
which is unnecessary for exponential family q. Sheth and Khardon [2016a] and Sheth et al. [2015] treat
non-conjugate models with Gaussian latent variables, employ the cancellation of the Fisher information,
and analyze convergence properties. Sheth and Khardon [2016b] further extend this to two-level models
through Monte Carlo, essentially applying the Fisher information cancellation to Salimans and Knowles
[2013], but they were unaware of this prior work.
These have the form of standard variational inference with the prior replaced by the previous variational
distribution q (i−1) (θ). This idea – combining the likelihood from a new batch of data with the previous
Algorithm 2 One step of the PEP algorithm at the i-th iteration, for the b_i-th data partition
Compute the tilted distribution: p̂_α^{(i)}(θ) = q^{(i−1)}(θ) [ p(y_{b_i}|θ) / t_{b_i}^{(i−1)}(θ) ]^α
Moment match: q_α^{(i)}(θ) = proj(p̂_α^{(i)}(θ)) such that E_{q(θ)}[T(θ)] = E_{p̂_α^{(i)}(θ)}[T(θ)]
Update the posterior distribution with damping ρ: q^{(i)}(θ) = ( q^{(i−1)}(θ) )^{1−ρ/α} ( q_α^{(i)}(θ) )^{ρ/α}
Update the approximate likelihood: t_{b_i}^{(i)}(θ) = [ q^{(i)}(θ) / q^{(i−1)}(θ) ] t_{b_i}^{(i−1)}(θ)
Algorithm 3 One step of the PEP algorithm, as in algorithm 2, but with alpha-divergence minimization
Compute the tilted distribution: p̂^{(i)}(θ) = q^{(i−1)}(θ) p(y_{b_i}|θ) / t_{b_i}^{(i−1)}(θ)
Find the posterior distribution: q^{(i)}(θ) := argmin_{q(θ)∈Q} D_α[ p̂^{(i)}(θ) ‖ q(θ) ]
Update the approximate likelihood: t_{b_i}^{(i)}(θ) = [ q^{(i)}(θ) / q^{(i−1)}(θ) ] t_{b_i}^{(i−1)}(θ)
approximate posterior and projecting back to a new approximate posterior – underpins online variational
inference [Ghahramani and Attias, 2000, Sato, 2001], streaming variational inference [Broderick et al.,
2013, Bui et al., 2017b], and variational continual learning [Nguyen et al., 2018]. Early work on online VI
used conjugate models and analytic updates [Ghahramani and Attias, 2000, Sato, 2001, Broderick et al.,
2013, Bui et al., 2017b]; this was followed by off-the-shelf optimization approaches for non-conjugate
models [Bui et al., 2017b] and further extended to leverage MC approximations of the local free-energy
[Nguyen et al., 2018]. Recently, Zeno et al. [2018] used the variational continual learning framework of
Nguyen et al. [2018], but employed fixed-point updates instead.
Property 6 The damped fixed point equations are precisely those returned by the PEP algorithm, shown
in algorithm 2, in the limit that α → 0.
Although we suspect Knowles and Minka [2011] knew of this relationship, and it is well known that
Power EP has the same fixed points as VI in this case, it does not appear to be widely known that
variationally limited Power EP yields exactly the same algorithm as fixed point local VI. See A.8 for
the proof.
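In the conjugate Gaussian case, this equivalence is visible directly in code: with natural parameters and exact moment matching, one step of algorithm 2 collapses to q_new = q + ρ(lik − t) for every α, so taking α → 0 changes nothing in this special case. The sketch below assumes this conjugate setting; the names are illustrative.

import numpy as np

# One PEP step (algorithm 2) for a conjugate Gaussian model, where the tilted
# distribution is Gaussian and proj() is exact. Algebraically,
# q_new = (1 - rho/alpha) q + (rho/alpha)(q + alpha (lik - t)) = q + rho (lik - t),
# independent of alpha, illustrating Property 6 in this special case.
def pep_step(q_nat, t_nat, lik_nat, alpha, rho):
    tilted = q_nat + alpha * (lik_nat - t_nat)   # p_hat_alpha (Gaussian)
    q_alpha = tilted                             # exact moment matching
    q_new = (1 - rho / alpha) * q_nat + (rho / alpha) * q_alpha
    t_new = t_nat + (q_new - q_nat)              # t * q_new / q_old
    return q_new, t_new

q, t, lik = np.array([0.0, -0.5]), np.zeros(2), np.array([2.0, -1.0])
for alpha in (1.0, 0.5, 1e-3):                   # same answer for every alpha
    print(pep_step(q, t, lik, alpha, rho=0.2)[0])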
Algorithm 4 One step of the SPEP algorithm at the i-th iteration, for the b_i-th data partition
Compute the tilted distribution: p̂_α^{(i)}(θ) = q^{(i−1)}(θ) [ p(y_{b_i}|θ) / t^{(i−1)}(θ) ]^α
Moment match: q_α^{(i)}(θ) = proj(p̂_α^{(i)}(θ)) such that E_{q(θ)}[T(θ)] = E_{p̂_α^{(i)}(θ)}[T(θ)]
Update the posterior distribution with damping ρ: q^{(i)}(θ) = ( q^{(i−1)}(θ) )^{1−Nρ/α} ( q_α^{(i)}(θ) )^{Nρ/α}
Update the approximate likelihood: t^{(i)}(θ) = ( q^{(i)}(θ) / p(θ) )^{1/N}
Table 1: Variational inference schemes encompassed by the PVI framework. (See next page.) Selected
past work has been organized into four categories: global VI (PVI with M = 1), fully local PVI (M = N ),
Power EP variants, and online VI. The citation to the work is provided along with the granularity
of the method (global indicates M = 1, fully local M = N , local implies general M can be used).
The optimization used from the PVI perspective on this work is noted. Abbreviations used here are:
Conjugate Gradient (CG) and Monte Carlo (MC). The model class that the scheme encompasses is noted
(conjugate versus non-conjugate) along with the specific models that the scheme was tested on. Model
abbreviations are: Non-linear State-space Model (NSSM), Non-linear Factor Analysis (NFA), Latent
Dirichlet Allocation (LDA), Poisson Mixed Model (PMM), Heteroscedastic Linear Regression (HLR),
Sparse Gaussian Processes (SGPs), Graphical Model (GM), Logistic Regression (LR), Beta-binomial
(BB), Stochastic Volatility model (SV), Probit Regression (PR), Multinomial Regression (MR), Bayesian
Neural Network (BNN), Gamma factor model (GFM), Poisson Gamma Matrix Factorization (PGMF),
Mixture of Gaussians (MoG), Gaussian Latent Variable (GLV). If the scheme proposed by the method has a name, this is noted in
the final column. Abbreviations of the inference scheme are: Automatic Differentiation VI (ADVI),
Incremental VI (IVI), Non-conjugate Variational Message Passing (NC-VMP), Simplified NC-VMP
(SNC-VMP), Conjugate-Computation VI (CCVI), Power EP (PEP), Alpha-divergence PEP (ADPEP),
Convergent Power EP (CPEP), Stochastic Power EP (SPEP), Variational Continual Learning (VCL),
Bayesian Gradient Descent (BGD).
Figure 3: The local VI framework unifies prior work. The granularity of the approximation and
the optimization method employed are two fundamental algorithmic dimensions that are shown as
axes. Fixed-point updates are identical to natural gradient ascent with unit step size. The models
encompassed by each paper are indicated by the color. See table 1 for more information.
Reference | Granularity | Optimization | Models | Name

Global VI [PVI M = 1, see section 4.1]
Beal [2003] | global | analytic | conjugate | VI
Sato [2001] | global | analytic | conjugate (MoG) |
Hinton and Van Camp [1993] | global | gradient ascent | non-conjugate (neural network) |
Honkela et al. [2010] | global | natural gradient (mean only) | non-conj. (MoG, NSSM, NFA) |
Hensman et al. [2012] | global | CG with natural gradient | conjugate |
Hensman et al. [2013] | global | stochastic natural gradient | conjugate |
Hoffman et al. [2013] | global | stochastic natural gradient | conjugate | SVI
Kucukelbir et al. [2017] | global | stochastic gradient descent | non-conjugate | ADVI
Salimans and Knowles [2013] | global | fixed-point + MC + stochastic | non-conjugate (PR, BB, SV) |
Sheth et al. [2015] | global | simplified fixed-point | non-conjugate (GLV) |
Sheth and Khardon [2016a] | global | simplified fixed-point | non-conjugate (GLV) |
Sheth and Khardon [2016b] | global | simplified fixed-point + MC | non-conjugate (two-level) |

Fully Local VI [PVI M = N, see section 4.2]
Winn et al. [2005] | fully local | analytic | conjugate (GM) | VMP
Archambeau and Ermis [2015] | fully local | incremental | conjugate (LDA) | IVI
Knowles and Minka [2011] | fully local | fixed-point | non-conjugate (LR, MR) | NC-VMP
Wand [2014] | fully local | simplified fixed-point | non-conjugate (PMM, HLR) | SNC-VMP
Khan and Lin [2018] | local | damped stochastic simplified fixed-point | non-conjugate (LR, GFM, PGMF) | CCVI

Online VI [one pass of PVI, see section 4.3]
Ghahramani and Attias [2000] | fully local | analytic | conjugate (MoG) |
Sato [2001] | fully local | analytic | conjugate (MoG) | online VB
Broderick et al. [2013] | fully local | analytic | conjugate (LDA) | streaming VI
Bui et al. [2017a] | fully local | analytic/LBFGS | conjugate and non-conj. (SGPs) |
Nguyen et al. [2018] | fully local | Adam | non-conjugate (BNN) | VCL
Zeno et al. [2018] | fully local | fixed-point | non-conjugate (BNN) | BGD

Power EP [PVI when α → 0, see sections 4.4 to 4.7]
Minka [2004] | local | series fixed-point | non-conjugate (GM) | PEP
Minka [2004] | local | optimization | | ADEP
Bui et al. [2017b] | local | analytic/fixed-point | conjugate / non-conj. (GPs) | PEP
Hasenclever et al. [2017] | local | analytic with MC | non-conjugate (BNN) | CPEP
Li et al. [2015] | local | stochastic fixed-point | non-conjugate (LR, BNN) | SPEP

Table 1: Variational inference schemes encompassed by the PVI framework. See the previous page for the full caption.
over the unknown parameters θ. Having specified the probability of everything, we turn the handle of
probability theory to obtain the posterior distribution,
p(θ|x, y) = p(θ) p(y|θ, x) / p(y|x) = p(θ) [ ∏_{k=1}^K ∏_{n=1}^{N_k} p(y_{k,n}|θ, x_{k,n}) ] / p(y|x).    (16)
The exact posterior above is analytically intractable and thus approximation techniques such as sampling
or deterministic methods are needed. There is a long history of research on approximate Bayesian
training of neural networks, including extended Kalman filtering [Singhal and Wu, 1989], Laplace’s
approximation [MacKay, 2003], Hamiltonian Monte Carlo [Neal, 1993, 2012], variational inference
[Hinton and Van Camp, 1993, Barber and Bishop, 1998, Graves, 2011, Blundell et al., 2015, Gal
and Ghahramani, 2016], sequential Monte Carlo [de Freitas et al., 2000], expectation propagation
[Hernández-Lobato and Adams, 2015], and approximate power EP [Li et al., 2015, Hernández-Lobato
et al., 2016]. In this section, we focus on Monte Carlo variational inference methods with a mean-field
Gaussian variational approximation [Graves, 2011, Blundell et al., 2015]. In detail, a factorized global
variational approximation, q(θ) = ∏_i N(θ_i; μ_i, σ_i²), is used to lower-bound the log marginal likelihood
as follows,
log p(y|x) = log ∫ dθ p(θ) p(y|θ, x) ≥ ∫ dθ q(θ) log [ p(θ) p(y|θ, x) / q(θ) ] = F_GVI(q(θ)),    (17)

where F_GVI(q(θ)) is the variational lower bound, or the negative variational free-energy. This bound
can be expanded as follows,
F_GVI(q(θ)) = −KL[ q(θ) ‖ p(θ) ] + Σ_{k=1}^K Σ_{n=1}^{N_k} ∫ dθ q(θ) log p(y_{k,n}|θ, x_{k,n}).    (18)
When the prior is chosen to be a Gaussian, the KL term in the bound above can be computed
analytically. In contrast, the expected log-likelihood term is not analytically tractable. However, it can
be approximated by simple Monte Carlo with the (local) reparameterization trick such that low-variance
stochastic gradients of the approximate expectation wrt the variational parameters {µi , σi } can be
easily obtained [Rezende et al., 2014, Kingma and Welling, 2014, Kingma et al., 2015].
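A minimal sketch of this estimator for a single weight and a Bernoulli likelihood follows, assuming the standard (non-local) reparameterization w = μ + σε; the data and parameter values are placeholders.

import numpy as np

# Reparameterization-trick gradients of the expected log-likelihood for a
# mean-field Gaussian over one weight (toy logistic regression).
rng = np.random.default_rng(2)
x, y = 1.5, 1.0                              # one (input, label) pair
mu, log_sigma = 0.1, -1.0                    # variational parameters

eps = rng.normal(size=1000)                  # eps ~ N(0, 1)
w = mu + np.exp(log_sigma) * eps             # w = mu + sigma * eps ~ q(w)
p = 1 / (1 + np.exp(-w * x))                 # Bernoulli probability
dlogp_dw = (y - p) * x                       # d log p(y | w, x) / dw
grad_mu = dlogp_dw.mean()                    # pathwise gradient, dw/dmu = 1
grad_log_sigma = (dlogp_dw * eps * np.exp(log_sigma)).mean()  # dw/dls = sigma*eps
print(grad_mu, grad_log_sigma)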
The variational lower-bound above can be optimized using any off-the-shelf stochastic optimizer, and
its gradient computation can be trivially distributed across many machines. A possible synchronously
distributed schedule when using K compute nodes, each having access to a memory shard, is as follows:
(i) a central compute node passes the current q(θ) to K workers, (ii) each worker then computes the
gradients of the expected log-likelihood of (a mini-batch of) its own data and passes the gradients
back to the central node, (iii) the central node aggregates these gradients, combines the result with
the gradient of the KL term, and performs an optimization step to obtain a new q(θ). These steps are
then repeated for a fixed number of iterations or until convergence. However, notice that this schedule
is communication-inefficient, as it requires frequent communication of the gradients and the updated
variational approximation between the central node and the K compute workers. We will next discuss
an inference scheme based on PVI that allows communication-efficient updates between workers and is
compatible with various scheduling schemes.
Following the PVI formulation in section 2, the approximate posterior can be rewritten using the
approximate factors, one for each memory shard, as follows,
p(θ|x, y) ∝ p(θ) ∏_{k=1}^K ∏_{n=1}^{N_k} p(y_{k,n}|θ, x_{k,n}) ≈ p(θ) ∏_{k=1}^K t_k(θ) = q(θ),    (19)
where t_k(θ) approximates the contribution of data points in the k-th shard to the posterior. As discussed
in the previous sections, PVI turns the original global approximate inference task into a collection of
approximate inference tasks, i.e. for the k-th memory shard and k-th compute node, the task is to
maximize,

F_PVI^k(q(θ)) = −KL[ q(θ) ‖ q^{\k}(θ) ] + Σ_{n=1}^{N_k} ∫ dθ q(θ) log p(y_{k,n}|θ, x_{k,n}),    (20)
where q^{\k}(θ) = q(θ)/t_k(θ) is the context or effective prior set by data points in other shards. Once a
new variational approximation q(θ) is obtained, a new approximate factor can be computed accordingly,
t_k(θ) = q(θ)/q^{\k}(θ). Note that the objective for each compute node is almost identical to the GVI
objective, except that the prior is now replaced by the context and the data are limited to the compute
node's accessible data. This means any global VI implementation available on a compute node (whether
using optimization, fixed-point updates, or closed-form solutions) can be trivially modified to handle PVI.
A key additional difference to GVI is the communication frequency between the compute nodes and
the central parameter server (that holds the latest q(θ)): a worker can decide to pass tk (θ) back to the
central server after multiple passes through its data, after one epoch, or after just one mini-batch. This
leaves room for practitioners to choose a learning schedule that meets communication constraints. More
importantly, PVI enables various communication strategies to be deployed, for example:
• Sequential PVI with only one pass through the data set: each worker, in turn, runs Global VI,
with the previous posterior being the prior/context, for the data points in its memory shard and
returns the posterior approximation to the parameter server. This posterior approximation will
then be used as the context for the next worker’s execution. Note that this is exactly equivalent
to Variational Continual Learning [Nguyen et al., 2018] and can be combined with the multihead
architecture, each head handling one task or one worker, or with episodic memory [see e.g. Zenke
et al., 2017, Nguyen et al., 2018]. This strategy is communication-efficient as only a small number
of messages are required — only one up/down update is needed for each worker.
• PVI with synchronous model updates: instead of sequentially updating the context distribution
and running only one worker at a time, all workers can be run in parallel. That is, each worker
occasionally sends its updated contribution to the posterior back to the parameter server. The
parameter server waits for all workers to finish before aggregating the approximate factors and
sending the new posterior back to the workers. The workers will then update their own context
distributions based on the current state of the central parameters. This process then repeats. By
analyzing the homogeneity of the data and updates across workers, heuristics could be used to
choose the learning rate for each worker and damping factor for the central parameter server —
we leave this for future work. A sketch of this aggregation step is given after this list.
• PVI with lock-free asynchronous updates: instead of waiting for all workers to finish training
locally, the model aggregation and update steps can be performed as soon as any worker has
finished. This strategy is particularly useful when communication is done over an unreliable
channel, the distribution of the data across different machines is highly unbalanced, or when a
machine can be disconnected from the training procedure at any time. However, this strategy
is expected to be generally worse compared the synchronous update scheme above, since the
context/cavity distribution could be changed while a worker is running and the next parameter
update performed by this worker could overwrite the updates made by other workers, i.e. there is
the possibility of stale updates.
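As a concrete example of the synchronous strategy referenced above, the sketch below shows one plausible server-side aggregation rule in natural-parameter space; the damping heuristic is our assumption for illustration, not a prescription from the experiments.

import numpy as np

# Hypothetical synchronous PVI aggregation: the server waits for all workers'
# factor deltas (natural-parameter differences) and applies a damped product,
# q_new = q_old * prod_k Delta_k^(1 - damping), before broadcasting q_new.
def aggregate(q_nat, deltas, damping=0.5):
    return q_nat + (1.0 - damping) * np.sum(deltas, axis=0)

q_nat = np.array([0.0, -0.5])                              # current posterior
deltas = [np.array([1.0, -0.4]), np.array([0.5, -0.2])]    # from two workers
q_nat = aggregate(q_nat, deltas, damping=0.5)
print(q_nat)                                               # broadcast next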
We demonstrate these communication strategies on a large-scale federated classification task in
section 7.1 and highlight the advantages and potential pitfalls of PVI, GVI and various alternatives for
different levels of data homogeneity across memory shards.
6 Improving Continual Learning for Sparse Gaussian Processes Using
PVI
Gaussian processes (GPs) are flexible probabilistic distributions over functions that have been used in a
wide variety of machine learning problems, including supervised learning [Rasmussen and Williams, 2006],
unsupervised learning [Lawrence, 2004] and reinforcement learning [Deisenroth, 2010]. The application
of GPs to more general, large-scale settings is however hindered by analytical and computational
intractabilities. As a result, a large body of active GP research aims to develop efficient approximation
strategies for inference and learning in GP models. In this work, we develop an approximation based
on partitioned variational inference for GP regression and classification in a continual learning setting.
In this setting, data arrive sequentially, either one data point at a time or in batches of a size that is
unknown a priori. An efficient strategy to accurately update the model in an online fashion is thus
needed and can be used for various applications such as control [Nguyen-Tuong et al., 2009] or mapping
[O’Callaghan and Ramos, 2012].
In particular, building on recent work on pseudo-point sparse approximations [Titsias, 2009, Hensman
et al., 2015, Matthews et al., 2016, Bui et al., 2017b] and streaming approximations [Csató and Opper,
2002, Bui et al., 2017a], we develop a streaming variational approximation that approximates the
posterior distribution over both the GP latent function and the hyperparameters for GP regression and
classification models. Additionally, the partitioned VI view of this approximation allows just-in-time,
dynamic allocation of new pseudo-points specific to a data batch, yielding more efficient training
and more accurate predictions in practice. We will provide a concise review of sparse approximations for
Gaussian process regression and classification before summarizing the proposed continual learning
approach. For interested readers, see Quiñonero-Candela and Rasmussen [2005], Bui et al. [2017a]
for more comprehensive reviews of sparse GPs. Appendix C contains the full derivation of different
streaming variational approaches with shared or private pseudo points, and with maximum likelihood
or variational learning strategies for the hyperparameters.
Given N input and output pairs {x_n, y_n}_{n=1}^N, a standard GP regression or classification model assumes
the outputs {y_n}_{n=1}^N are generated from the inputs {x_n}_{n=1}^N according to y_n = f(x_n) + ξ_n, where
f is an unknown function that is corrupted by observation noise, for example, ξ_n ∼ N(0, σ_y²) in the
real-valued output regression problem.⁴ Typically, f is assumed to be drawn from a zero-mean GP
prior f ∼ GP(0, k(·, ·|ϵ)) whose covariance function depends on hyperparameters ϵ. We also place
a prior over the hyperparameters and as such inference involves finding the posterior over both f
and ϵ, p(f, ϵ|y, x), and computing the marginal likelihood p(y|x), where we have collected the inputs
and observations into vectors x = {x_n}_{n=1}^N and y = {y_n}_{n=1}^N respectively. This is one key difference
to the work of Bui et al. [2017b], in which only a point estimate of the hyperparameters is learned
via maximum likelihood. The dependence on the inputs of the posterior, marginal likelihood, and
other quantities will be suppressed when appropriate to lighten the notation. Exact inference in the
model considered here is analytically and computationally intractable, due to the non-linear dependency
between f and ϵ, and the need to perform a high-dimensional integration when N is large.
In this work, we focus on the variational free energy approximation scheme [Titsias, 2009, Matthews
et al., 2016] which is arguably the leading approximation method for many scenarios. This scheme
lower bounds the marginal likelihood of the model using a variational distribution q(f, ϵ) over the latent
⁴In this section, f stands for the model parameters, as denoted by θ in the previous sections.
function and the hyperparameters:

log p(y|x) ≥ ∫ df dϵ q(f, ϵ) log [ p(y|f, ϵ, x) p(f|ϵ) p(ϵ) / q(f, ϵ) ] =: F(q(f, ϵ)),

where F(q(f, ϵ)) is the variational surrogate objective and can be maximized to obtain q(f, ϵ). In order
to arrive at a computationally tractable method, the approximate posterior is parameterized via a
set of M_a pseudo-outputs a which are a subset of the function values f = {f_{≠a}, a}. Specifically, the
approximate posterior takes the following structure:

q(f, ϵ) = p(f_{≠a}|a, ϵ) q(a) q(ϵ),    (21)
where q(a) and q(ϵ) are variational distributions over a and ϵ respectively, and p(f_{≠a}|a, ϵ) is the
conditional prior distribution of the remaining latent function values. Note that while a and ϵ are
assumed to be factorized in the approximate posterior, the dependencies between the remaining latent
function values f_{≠a} and the hyperparameters ϵ, and between f_{≠a} themselves, are retained due to the
conditional prior. This assumption leads to a critical cancellation that results in a computationally
tractable lower bound as follows:

F(q(a), q(ϵ)) = ∫ df dϵ q(f, ϵ) log [ p(y|f, ϵ, x) p(ϵ) p(a|ϵ) p(f_{≠a}|a, ϵ) / ( p(f_{≠a}|a, ϵ) q(a) q(ϵ) ) ]
             = −KL[ q(ϵ) ‖ p(ϵ) ] − ∫ dϵ q(ϵ) KL[ q(a) ‖ p(a|ϵ) ]
               + Σ_n ∫ dϵ da df_n q(ϵ) q(a) p(f_n|a, ϵ) log p(y_n|f_n, ϵ, x_n),
where the conditional prior p(f_{≠a}|a, ϵ) cancels in the first line, and f_n = f(x_n) is the latent function
value at x_n. Most terms in the variational lower bound above require computation of an expectation
wrt the variational approximation q(ϵ), which is not available in closed form even when q(ϵ) takes a
simple form such as a diagonal Gaussian. However, these expectations can be approximated by simple
Monte Carlo with the reparameterization trick [Kingma and Welling, 2014, Rezende et al., 2014]. The
remaining expectations can be handled tractably, either in closed form or by using Gaussian quadrature.
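As an illustration of this Monte Carlo strategy, the sketch below estimates E_{q(ϵ)} KL[q(a) ‖ p(a|ϵ)] for a single pseudo-output, using N(0, exp(ϵ)) as a toy stand-in for the GP prior marginal at one pseudo-input; the distributions and values are assumptions for the example.

import numpy as np

# Reparameterized MC estimate of E_{q(eps)} KL[q(a) || p(a | eps)] with
# p(a | eps) = N(0, exp(eps)) standing in for the GP prior at a pseudo-input.
rng = np.random.default_rng(3)
m_e, s_e = 0.0, 0.1                     # q(eps) = N(m_e, s_e^2)
m_a, v_a = 0.5, 0.2                     # q(a)   = N(m_a, v_a)

eps = m_e + s_e * rng.normal(size=10_000)   # eps = m_e + s_e * xi, xi ~ N(0,1)
pv = np.exp(eps)                            # prior variance under each sample
# KL[N(m_a, v_a) || N(0, pv)] in closed form, averaged over eps samples:
kl = 0.5 * (np.log(pv / v_a) + (v_a + m_a**2) / pv - 1.0)
print(kl.mean())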
The posterior after observing the first data batch, y_1, can then be approximated as

p(f, ϵ|y_1, x_1) ≈ p(f_{≠a}|a, ϵ) t_1(a) g_1(a) p(ϵ) t_1(ϵ) g_1(ϵ) = p(f_{≠a}|a, ϵ) q_1(a) q_1(ϵ),

where q_1(a) = t_1(a) g_1(a), q_1(ϵ) = p(ϵ) t_1(ϵ) g_1(ϵ), and the t_1(·)s and g_1(·)s are introduced to approximate
the contribution of p(a|ϵ) and p(y_1|f, ϵ, x_1) to the posterior, respectively. Note that the last equation
is identical to eq. (21), but the factor representation above facilitates inference using PVI and allows
more flexible approximations in the streaming setting. In particular, the exact posterior when both old
data y_1 and newly arrived data y_2 are included can be approximated in a similar fashion,

p(f, ϵ|y_1, y_2, x_1, x_2) ≈ p(ϵ) p(f_{≠a,b}|a, b, ϵ) t_1(a) g_1(a) t_2(b|a) g_2(b) t_1(ϵ) g_1(ϵ) t_2(ϵ) g_2(ϵ),

where b are new pseudo-outputs, and the t_2(·)s and g_2(·)s are approximate contributions of p(b|a, ϵ) and
p(y_2|f, ϵ, x_2) to the posterior, respectively. As we have reused the approximate factors t_1(a) and g_1(a),
the newly introduced pseudo-points b can be thought of as pseudo-points private to the new data. This
is the key difference of this work compared to the approach of Bui et al. [2017a], in which both old
and new data share the same set of pseudo-points. The advantages of the approach based on private
pseudo-points are potentially two-fold: (i) it is conceptually simpler to focus the approximation effort on
the new data points while keeping the approximation for previous data fixed, as a new data batch
may require only a small number of representative pseudo-points, and (ii) the number of parameters
(variational parameters and private pseudo-inputs) is much smaller, leading to an arguably easier
initialization and optimization problem.
The approximate factors t_2(·) and g_2(·) can be found by employing the PVI algorithm in section 2.
Alternatively, in the continual learning setting where data points do not need to be revisited, we can
convert the factor-based variational approximation above to a global variational approximation,

p(f, ϵ|y_1, y_2, x_1, x_2) ≈ p(ϵ) p(f_{≠a,b}|a, b, ϵ) t_2(b|a) t_2(ϵ) t_1(a) t_1(ϵ) g_1(a) g_1(ϵ) g_2(b) g_2(ϵ)
                          = p(f_{≠a,b}|a, b, ϵ) q_2(b|a) q_1(a) q_2(ϵ),

where q_2(b|a) = t_2(b|a) g_2(b), q_2(ϵ) = p(ϵ) t_1(ϵ) g_1(ϵ) t_2(ϵ) g_2(ϵ), and q_2(b|a) and q_2(ϵ) are parameterized
and optimized, along with the location of b. While this does not change the fixed-point solution
compared to the PVI algorithm, it allows existing sparse global VI implementations such as that in
GPflow [Matthews et al., 2017] to be easily extended and deployed.
7 Experiments
Having discussed the connections to the literature and developed two novel applications of PVI, we
validate the proposed methods by running a suite of continual and federated learning experiments on
Bayesian neural network and Gaussian process models.
communication between workers is managed using Ray [Moritz et al., 2017]. We use Adam [Kingma
and Ba, 2014] for the inner loop optimization for partitioned, distributed methods or the outer loop
optimization for global VI, and mini-batches of 200 data points. In the next few paragraphs, we briefly
detail the methods compared in this section and their results.
Global VI We first evaluate global VI with a diagonal Gaussian variational approximation for the
weights in the neural network. In particular, it is assumed that there is only one compute node (with
either one core or ten cores) that can access the entire data set. This compute node maintains a
global variational approximation to the exact posterior, and adjusts this approximation using the noisy
gradients of the variational free-energy in eq. (18). We simulate the data distribution by sequentially
showing mini-batches that can potentially have all ten classes (iid) or that have data of only one class
(non-iid). Figures 13 and 14 in the appendix show the full performance statistics on the test set during
training for different learning rates and data homogeneity levels. The performance depends strongly on
the learning rate, especially when the mini-batches are non-iid. Faster convergence early in training does
not guarantee a better eventual model as measured by test performance in the iid setting. Note that
GVI for the non-iid setting can arrive at a good test error rate, albeit requiring a much smaller learning
rate and a substantially longer training time. In addition, global VI is not communication-efficient, as
the global parameters are updated as often as data mini-batches are considered. The best performing
method for the iid/non-iid settings is selected from all of the learning rates considered and shown
in figs. 4 and 5.
Bayesian committee machine The Bayesian committee machine (BCM) is a simple baseline which
is naturally applicable to partitioned data [Tresp, 2000]. The BCM performs (approximate) inference
for each data shard independently of the other data shards and aggregates the sub-posteriors at the end.
In particular, global VI with a diagonal Gaussian variational approximation is applied independently
to the data in each shard, yielding approximate local posteriors {q_k(θ)}_{k=1}^K. The aggregation step
involves multiplying K Gaussian densities. The only shared information across different members of
the committee is the prior. This baseline, therefore, assesses the benefits of coordination between the
workers. We consider two prior sharing strategies as follows,
BCM — same: p(θ|x, y) ∝ p(θ) ∏_{k=1}^K p(y_k|θ, x_k) = ∏_{k=1}^K [ p(θ) p(y_k|θ, x_k) ] / [p(θ)]^{K−1} ≈ ∏_{k=1}^K q_k(θ) / [p(θ)]^{K−1},

BCM — split: p(θ|x, y) ∝ p(θ) ∏_{k=1}^K p(y_k|θ, x_k) = ∏_{k=1}^K [ [p(θ)]^{N_k/N} p(y_k|θ, x_k) ] ≈ ∏_{k=1}^K q_k(θ).
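Both prior-sharing rules are straightforward for diagonal Gaussian sub-posteriors, since multiplying Gaussians (and dividing out copies of the prior) amounts to adding and subtracting natural parameters; the sketch below, with illustrative values, shows the aggregation for a single parameter.

# BCM aggregation for Gaussian sub-posteriors: for "same", divide the product
# of the q_k by p(theta)^(K-1); for "split", the natural parameters simply add.
def bcm_aggregate(means, variances, prior_mean, prior_var, scheme="same"):
    n1 = sum(m / v for m, v in zip(means, variances))
    n2 = sum(-0.5 / v for v in variances)
    if scheme == "same":
        K = len(means)
        n1 -= (K - 1) * prior_mean / prior_var
        n2 -= (K - 1) * (-0.5 / prior_var)
    var = -0.5 / n2
    return n1 * var, var                       # (aggregated mean, variance)

means, variances = [0.9, 1.1, 1.0], [0.2, 0.3, 0.25]
print(bcm_aggregate(means, variances, 0.0, 1.0, scheme="same"))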
BCM is fully parallelizable (one worker independently performs inference for one shard) and is
communication-efficient (only one round of communication is required at the end of the training).
However, there are several potential disadvantages: (i) it is not clear whether the prior sharing schemes
discussed above will over-regularize or under-regularize the network compared to the original batch
training scheme, and (ii) since each shard develops independent approximations and the model is
unidentifiable, it is unclear if the simple combination rules above are appropriate. For example, different
members of the committee might learn equivalent posteriors up to a permutation of the hidden units.
Although initializing each approximate posterior qk (θ) in the same way can mitigate this effect, the
lack of a shared context is likely to be problematic. We evaluate BCM for both iid and non-iid settings,
with different learning rates for the Adam optimizer for each worker and show the full results in figs. 15
and 16. It is perhaps surprising that BCM works well in the iid data setting, although the best error
rate of 4% is still much higher than state-of-the-art results on the MNIST classification task (∼ 1%).
However, the concern above about the potential pitfalls when multiplying different sub-posteriors is
validated in the non-iid setting, and in the iid setting when each worker is trained for a long time before
performing the aggregation step. The best results in both settings are selected and shown in figs. 4
and 5.
• Sequential PVI with only one pass through the data set: The number of training epochs and
learning rates for each worker are varied, and the full results are included in figs. 17 and 18.
The results show that while this strategy is effective for the iid setting, it performs poorly in
the non-iid setting. This issue is well documented in the continual learning literature, where
incremental learning of a single-head network is known to be challenging. Episodic memory [see e.g. Zenke
et al., 2017, Nguyen et al., 2018] or generative replay [Shin et al., 2017] is typically used to address
this problem. The performance for the best hyperparameter settings is shown in figs. 4 and 5.
• PVI with synchronous model updates: In this experiment, each worker runs one epoch of Global
VI between message passing steps, and the parameter server waits for all workers to finish before
aggregating information and sending it back to the workers. We explore different learning rates
for the inner loop optimization and various damping rates for the parameter server, and show
the full results in figs. 19 and 20. While the performance on the test set depends strongly on the
learning rate and damping factor, if these values are appropriately chosen, this update strategy
can achieve competitive performance (≲ 2% test error). By analyzing the homogeneity of the data and
updates across workers, some form of heuristic could be used to choose the learning rate and
damping factor — we leave this for future work. We pick the best performing runs and compare
with other methods in figs. 4 and 5.
• PVI with lock-free asynchronous updates: Similar to the synchronous PVI experiment, we vary
the learning rate and damping factor and include the full results in figs. 22 and 23. The test
performance of this method is generally worse than that of the synchronous update scheme, since the
context/cavity distribution could change while a worker is running and the next parameter
update performed by this worker could overwrite the updates made by other workers. While we
do not simulate conditions that favour this scheduling scheme, such as unreliable communication
channels or unbalanced data across memory shards, we expect this strategy to perform well
compared to other methods in those scenarios. The best hyperparameters are selected and their
performance is shown in figs. 4 and 5.
Discussion The best performance for each method discussed above is shown in figs. 4 and 5,
demonstrating the accuracy-training time and accuracy-communication cost frontiers. In the iid data
setting (fig. 4), distributed training methods can achieve performance comparable to that of global VI
in the same training time. However, methods based on data partitioning are much more
communication-efficient; for example, PVI-sync uses about 10 times fewer messages than GVI when both
methods attain a 3% test error. The results for PVI-seq with one pass through the data demonstrate
its efficiency, but highlight the need to revisit data multiple times to obtain a better error rate and
log-likelihood. BCM shows promising performance, but is outperformed by all other methods, suggesting
that communication between workers and setting the right approximate prior (context) for each partition
are crucial.
Figure 5 shows that the non-iid data regime is substantially more challenging: simple training methods
such as BCM and PVI-seq with one pass perform poorly, and the other distributed methods require more
extreme hyperparameter settings (e.g. a much smaller learning rate and a higher damping factor), much
longer training time, and higher communication cost to obtain performance comparable to that in the
iid regime. We note that the performance of PVI is significantly better than a recent result by Zhao et al.
[2018], who achieved a 10% error rate on the same non-iid data setting. Moreover, unlike that work, we
use a fully-connected neural network (rather than a convolutional one) and do not communicate data
between the workers (no data synchronization). As in the iid setting, the performance of PVI-async is
hindered by stale updates compared to PVI-sync, despite faster early convergence. While GVI with 10
cores is the best performer in terms of predictive performance, it is the least communication-efficient due
to the need to frequently pass gradients between the central parameter server and the compute nodes. This,
however, suggests that the performance of PVI could be further improved by more frequent updates
between workers, essentially trading off communication cost for more accurate predictions.
[Figure 4 plots: (a) classification error (%) and NLL (nats) vs train time (s); (b) classification error (%) and NLL (nats) vs the number of worker-to-server communication messages, for BCM (split/same prior, 10 subsets), GVI (1 and 10 cores), and PVI (seq with one pass, async, sync; 10 workers).]
Figure 4: Performance on the test set in the federated MNIST experiment with an iid distribution of training points across ten workers. The test performance
is measured using the classification error [error] and the negative log-likelihood [nll], and for both measures, lower is better. All methods are assessed using the
performance vs train time and performance vs communication cost plots — closer to the bottom left of the plots is better. Methods used for benchmarking are:
Bayesian Committee Machines (BCM) with the standard Normal prior [same] and with a weakened prior [split], Global VI (GVI) with one and ten compute cores, PVI
with sequential updates and only one pass through the data [equivalent to Variational Continual Learning], PVI with lock-free asynchronous updates (PVI - async),
and PVI with synchronous updates (PVI - sync). For ease of presentation, the x-axes for the plots start at 1. See text for more details. Best viewed in colour.
(a) Error and NLL vs train time (b) Error and NLL vs communication cost
Figure 5: Performance on the test set in the federated MNIST experiment with a non-iid distribution of training points across ten workers, i.e. each worker has access
to digits of only one class. The test performance is measured using the classification error [error] and the negative log-likelihood [nll], and for both measures, lower is
better. All methods are assessed using the performance vs train time and performance vs communication cost plots — closer to the bottom left of the plots is better.
Methods used for benchmarking are: Bayesian Committee Machines (BCM) with the standard Normal prior [same] and with a weakened prior [split], Global VI (GVI)
with one and ten compute cores, PVI with sequential updates and only one pass through the data [equivalent to Variational Continual Learning], PVI with lock-free
asynchronous updates (PVI - async), and PVI with synchronous updates (PVI - sync). For ease of presentation, the x-axes for the plots start at 1. See text for more
details. Best viewed in colour.
7.2 Improving Continual Learning for Sparse Gaussian Processes
We evaluate the performance of the continual learning method for sparse Gaussian process models
discussed in section 6 on a toy classification problem and a real-world regression problem. The different
inference strategies were implemented in TensorFlow [Abadi et al., 2016] and GPflow [Matthews et al.,
2017].
Figure 6: Experimental results on a toy streaming data set: predictions after sequentially observing
different data batches [left] and corresponding hyperparameter estimates [right]. Three methods were
considered for this task: MCMC for both the latent function (f ) and hyperparameters (hypers) with no
sparsification, variational inference for both f and hypers with inducing points, and sparse variational
inference for f and maximum likelihood estimation for hypers. Best viewed in colour.
with shared pseudo-points has more pseudo-points for earlier batches and is theoretically expected to
perform better, this experiment shows its inferior performance due to the need to reinitialize and optimize
all pseudo-points at every batch. The full GP with limited memory performs poorly and
exhibits forgetting as more batches arrive and old data are excluded from the memory. In addition,
we also tried the streaming inference scheme of Bui et al. [2017a], which only retains a point estimate
of the hyperparameters after each batch, but this did not perform well, demonstrating that being
distributional over the hyperparameters is crucial.
8 Conclusion
This paper provided a general and unifying view of variational inference, Partitioned Variational
Inference, for probabilistic models. We showed that the PVI framework flexibly subsumes many existing
variational inference methods as special cases and allows a wealth of techniques to be connected. We
also demonstrated how PVI allows novel algorithmic developments and practical applications. This was
illustrated through the development of a streaming variational inference scheme for Gaussian process
models and a communication-efficient algorithm for training Bayesian neural networks on federated
data.
[Figure 7 plots: SMSE against batch number for DOF 1-4, comparing streaming sparse GP with private pseudo-points (u), streaming sparse GP with shared pseudo-points, and full GP with limited memory.]
Figure 7: Experimental results of learning inverse dynamics of a robot arm, i.e. predicting the forces
applied to the joints of the arm given the locations, speeds, and accelerations of the joints. Three
methods were considered for this task: streaming sparse variational GP with private pseudo-points,
streaming sparse variational GP with shared pseudo-points, and full GP with limited memory. Full
details are included in the text. Best viewed in colour.
One of the key contributions of this work is the connection of deterministic local message passing
methods with global optimization-based schemes. These branches of approximate inference have
arguably been developed independently of each other; for example, existing probabilistic
programming toolkits tend to work primarily with one approach but not both [see e.g. Tran et al., 2017,
Minka et al., 2018]. This paper suggests these methods are inter-related and that practitioners could benefit
from a unified framework, i.e. there are ways to expand existing probabilistic programming packages
to gracefully handle both approaches. Additionally, the PVI framework could be used to automatically
choose a granularity level and an optimization scheme that potentially offer a better inference method
for the task or the model at hand. It is, however, unclear how flexible variational approximations such
as mixtures of exponential family distributions [see e.g. Sudderth et al., 2010] or normalizing flows
[Rezende and Mohamed, 2015] can be efficiently and tractably accommodated in the PVI framework.
We leave these directions as future work.
The experiments in section 7.1 demonstrated that PVI is well-suited to learning with decentralized
data. Deployment of PVI in this setting is practical as its implementation only requires a straightforward
modification of existing global VI implementations. We have also explored how this algorithm allows
data parallelism (each local worker stores a complete copy of the model) and communication-efficient,
uncertainty-aware updates between workers. A potential future extension of the proposed approach is
model parallelism, that is, in addition to decentralizing the data and computation across the workers,
partitioning the model itself. As commonly done in many deep learning training algorithms, model
parallelism could be achieved by assigning the parameters (and computation) of different layers of the
network to different devices. Another potential research direction is coordinator-free, peer-to-peer-only
communication between workers: rather than communicating with a central parameter server, a worker
would pass each update to several randomly selected other workers, who then apply the changes.
References
Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu
Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: a system for large-scale
machine learning. In Proceedings of the 12th USENIX Symposium on Operating Systems Design and
Implementation, 2016.
Cedric Archambeau and Beyza Ermis. Incremental variational inference for latent Dirichlet allocation.
arXiv preprint arXiv:1507.05016, 2015.
David Barber and Christopher M. Bishop. Ensemble learning in Bayesian neural networks. In Neural
Networks and Machine Learning, 1998.
Matthew James Beal. Variational algorithms for approximate Bayesian inference. PhD thesis, UCL,
2003.
David M Blei, Andrew Y Ng, and Michael I Jordan. Latent Dirichlet allocation. Journal of Machine
Learning Research, 3(Jan):993-1022, 2003.
Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in
neural networks. In International Conference on Machine Learning, pages 1613-1622, 2015.
Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, Antonio Marcedone, H Brendan McMahan, Sarvar
Patel, Daniel Ramage, Aaron Segal, and Karn Seth. Practical secure aggregation for privacy-
preserving machine learning. In Proceedings of the 2017 ACM SIGSAC Conference on Computer and
Communications Security, pages 1175–1191. ACM, 2017.
Tamara Broderick, Nicholas Boyd, Andre Wibisono, Ashia C. Wilson, and Michael I. Jordan. Streaming
variational Bayes. In Advances in Neural Information Processing Systems, 2013.
Thang D. Bui, Cuong V. Nguyen, and Richard E. Turner. Streaming sparse Gaussian process approxi-
mations. In Advances in Neural Information Processing Systems, 2017a.
Thang D. Bui, Josiah Yan, and Richard E. Turner. A unifying framework for Gaussian process pseudo-
point approximations using power expectation propagation. Journal of Machine Learning Research,
18(104):1–72, 2017b.
Arslan Chaudhry, Puneet K Dokania, Thalaiyasingam Ajanthan, and Philip HS Torr. Riemannian walk
for incremental learning: Understanding forgetting and intransigence. arXiv preprint arXiv:1801.10112,
2018.
Jianmin Chen, Xinghao Pan, Rajat Monga, Samy Bengio, and Rafal Jozefowicz. Revisiting distributed
synchronous SGD. arXiv preprint arXiv:1604.00981, 2016.
Lehel Csató and Manfred Opper. Sparse online Gaussian processes. Neural Computation, 2002.
Nando de Freitas, Mahesan Niranjan, Andrew H. Gee, and Arnaud Doucet. Sequential Monte Carlo
methods to train neural network models. Neural Computation, 2000.
Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Andrew Senior,
Paul Tucker, Ke Yang, Quoc V Le, et al. Large scale distributed deep networks. In Advances in
Neural Information Processing Systems, pages 1223–1231, 2012.
Marc Peter Deisenroth. Efficient reinforcement learning using Gaussian processes. PhD thesis, University
of Cambridge, 2010.
Cynthia Dwork and Aaron Roth. The algorithmic foundations of differential privacy. Foundations and
Trends in Theoretical Computer Science, 9(3-4):211-407, 2014.
Timo Flesch, Jan Balaguer, Ronald Dekker, Hamed Nili, and Christopher Summerfield. Comparing
continual task learning in minds and machines. Proceedings of the National Academy of Sciences, 115
(44):E10313–E10322, 2018.
Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model
uncertainty in deep learning. In International Conference on Machine Learning, 2016.
Andrew Gelman, Aki Vehtari, Pasi Jylänki, Christian Robert, Nicolas Chopin, and John P Cunningham.
Expectation propagation as a way of life. arXiv preprint arXiv:1412.4869, 2014.
Zoubin Ghahramani and H. Attias. Online variational Bayesian learning. In NIPS Workshop on Online
Learning, 2000.
Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural
networks. In International Conference on Artificial Intelligence and Statistics, pages 249–256, 2010.
Ian J. Goodfellow, Mehdi Mirza, Da Xiao, Aaron Courville, and Yoshua Bengio. An empirical
investigation of catastrophic forgetting in gradient-based neural networks. In International Conference
on Learning Representations, 2014.
Alex Graves. Practical variational inference for neural networks. In Advances in Neural Information
Processing Systems, 2011.
Leonard Hasenclever, Stefan Webb, Thibaut Lienart, Sebastian Vollmer, Balaji Lakshminarayanan,
Charles Blundell, and Yee Whye Teh. Distributed Bayesian learning with stochastic natural gradient
expectation propagation and the posterior server. Journal of Machine Learning Research, 18(106):
1-37, 2017. URL http://jmlr.org/papers/v18/16-478.html.
Tyler L Hayes, Ronald Kemker, Nathan D Cahill, and Christopher Kanan. New metrics and experimental
paradigms for continual learning. In Proceedings of the IEEE Conference on Computer Vision and
Pattern Recognition Workshops, pages 2031–2034, 2018.
James Hensman, Magnus Rattray, and Neil D. Lawrence. Fast variational inference in the conjugate
exponential family. In Advances in Neural Information Processing Systems, pages 2888–2896, 2012.
James Hensman, Nicolo Fusi, and Neil D Lawrence. Gaussian processes for big data. In Uncertainty in
Artificial Intelligence, page 282, 2013.
James Hensman, Alexander G. D. G. Matthews, and Zoubin Ghahramani. Scalable variational Gaussian
process classification. In International Conference on Artificial Intelligence and Statistics, 2015.
José Miguel Hernández-Lobato and Ryan P. Adams. Probabilistic backpropagation for scalable learning
of Bayesian neural networks. In International Conference on Machine Learning, 2015.
José Miguel Hernández-Lobato, Yingzhen Li, Mark Rowland, Daniel Hernández-Lobato, Thang D.
Bui, and Richard E. Turner. Black-box α-divergence minimization. In International Conference on
Machine Learning, 2016.
Geoffrey E Hinton and Drew Van Camp. Keeping the neural networks simple by minimizing the
description length of the weights. In Conference on Computational Learning Theory, pages 5–13,
1993.
Matthew D Hoffman, David M Blei, Chong Wang, and John Paisley. Stochastic variational inference.
Journal of Machine Learning Research, 14(1):1303–1347, 2013.
Antti Honkela, Tapani Raiko, Mikael Kuusela, Matti Tornio, and Juha Karhunen. Approximate
Riemannian conjugate gradient learning for fixed-form variational Bayes. Journal of Machine Learning
Research, 11(Nov):3235-3268, 2010.
Tommi S Jaakkola and Michael I Jordan. Improving the mean field approximation via the use of
mixture distributions. In Learning in graphical models, pages 163–173. Springer, 1998.
Mohammad E Khan, Pierre Baqué, François Fleuret, and Pascal Fua. Kullback-Leibler proximal
variational inference. In Advances in Neural Information Processing Systems, pages 3402–3410, 2015.
Mohammad Emtiyaz Khan, Reza Babanezhad, Wu Lin, Mark Schmidt, and Masashi Sugiyama. Faster
stochastic variational inference using proximal-gradient methods with general divergence functions.
In Conference on Uncertainty in Artificial Intelligence, pages 319–328, 2016.
Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International
Conference on Learning Representations, 2014.
Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. In International Conference on
Learning Representations, 2014.
Diederik P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameteri-
zation trick. In Advances in Neural Information Processing Systems, pages 2575–2583, 2015.
James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A.
Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis,
Claudia Clopath, Dharshan Kumaran, and Raia Hadsell. Overcoming catastrophic forgetting in
neural networks. Proceedings of the National Academy of Sciences, 2017.
David A. Knowles and Tom Minka. Non-conjugate variational message passing for multinomial and
binary regression. In Advances in Neural Information Processing Systems, pages 1701–1709, 2011.
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M Blei. Automatic
differentiation variational inference. Journal of Machine Learning Research, 18(1):430–474, 2017.
Neil D Lawrence. Gaussian process latent variable models for visualisation of high dimensional data. In
Advances in Neural Information Processing Systems, pages 329–336, 2004.
Yingzhen Li, José Miguel Hernández-Lobato, and Richard E. Turner. Stochastic expectation propagation.
In Advances in Neural Information Processing Systems, pages 2323–2331, 2015.
Zhizhong Li and Derek Hoiem. Learning without forgetting. In European Conference on Computer
Vision, 2016.
Vincenzo Lomonaco and Davide Maltoni. Core50: a new dataset and benchmark for continuous object
recognition. In Conference on Robot Learning, pages 17–26, 2017.
David JC MacKay. Information theory, inference and learning algorithms. Cambridge University Press,
2003.
Alexander G De G Matthews, Mark Van Der Wilk, Tom Nickson, Keisuke Fujii, Alexis Boukouvalas,
Pablo León-Villagrá, Zoubin Ghahramani, and James Hensman. GPflow: A Gaussian process library
using TensorFlow. Journal of Machine Learning Research, 18(1):1299–1304, 2017.
Michael McCloskey and Neal J. Cohen. Catastrophic interference in connectionist networks: The
sequential learning problem. Psychology of Learning and Motivation, 1989.
Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas.
Communication-efficient learning of deep networks from decentralized data. In International Confer-
ence on Artificial Intelligence and Statistics, pages 1273–1282, 2017.
Lars Mescheder, Sebastian Nowozin, and Andreas Geiger. Adversarial variational bayes: Unifying
variational autoencoders and generative adversarial networks. In International Conference on Machine
Learning, pages 2391–2400, 2017.
T. Minka, J.M. Winn, J.P. Guiver, Y. Zaykov, D. Fabian, and J. Bronskill. Infer.NET 0.3, 2018.
Microsoft Research Cambridge. http://dotnet.github.io/infer.
Thomas P. Minka. Power EP. Technical report, Microsoft Research Cambridge, 2004.
Philipp Moritz, Robert Nishihara, Stephanie Wang, Alexey Tumanov, Richard Liaw, Eric Liang, William
Paul, Michael I Jordan, and Ion Stoica. Ray: A distributed framework for emerging AI applications.
arXiv preprint arXiv:1712.05889, 2017.
Radford M. Neal. Bayesian learning via stochastic dynamics. In Advances in Neural Information
Processing Systems, pages 475–482, 1993.
Radford M Neal. Bayesian learning for neural networks, volume 118. Springer Science & Business
Media, 2012.
Cuong V. Nguyen, Yingzhen Li, Thang D. Bui, and Richard E. Turner. Variational continual learning.
In International Conference on Learning Representations, 2018.
Duy Nguyen-Tuong, Jan R Peters, and Matthias Seeger. Local Gaussian process regression for real
time online model learning. In Advances in Neural Information Processing Systems, pages 1193–1200,
2009.
Manfred Opper. A Bayesian approach to online learning. In Online Learning in Neural Networks,
pages 363-378. Cambridge University Press, 1998.
Simon T O’Callaghan and Fabio T Ramos. Gaussian process occupancy maps. The International
Journal of Robotics Research, 31(1):42–62, 2012.
Neal Parikh and Stephen Boyd. Proximal algorithms. Foundations and Trends in Optimization, 1(3):
127-239, 2014.
Joaquin Quiñonero-Candela and Carl E. Rasmussen. A unifying view of sparse approximate Gaussian
process regression. Journal of Machine Learning Research, 2005.
Rajesh Ranganath, Dustin Tran, and David Blei. Hierarchical variational models. In International
Conference on Machine Learning, pages 324–333, 2016.
Garvesh Raskutti and Sayan Mukherjee. The information geometry of mirror descent. IEEE Transactions
on Information Theory, 61(3):1451–1457, March 2015.
Carl E. Rasmussen and Christopher K. I. Williams. Gaussian Processes for Machine Learning. The
MIT Press, 2006.
Roger Ratcliff. Connectionist models of recognition memory: Constraints imposed by learning and
forgetting functions. Psychological Review, 1990.
Danilo Rezende and Shakir Mohamed. Variational inference with normalizing flows. In International
Conference on Machine Learning, pages 1530–1538, 2015.
Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and
approximate inference in deep generative models. In International Conference on Machine Learning,
pages 1278–1286, 2014.
Tim Salimans and David A Knowles. Fixed-form variational posterior approximation through stochastic
linear regression. Bayesian Analysis, 8(4):837–882, 2013.
Tim Salimans, Diederik Kingma, and Max Welling. Markov chain Monte Carlo and variational inference:
Bridging the gap. In International Conference on Machine Learning, pages 1218–1226, 2015.
Hugh Salimbeni, Stefanos Eleftheriadis, and James Hensman. Natural gradients in practice: Non-
conjugate variational inference in Gaussian process models. In International Conference on Artificial
Intelligence and Statistics, 2018.
Masa-Aki Sato. Online model selection based on the variational Bayes. Neural Computation, 2001.
Jeffrey C. Schlimmer and Douglas Fisher. A case study of incremental concept induction. In The
National Conference on Artificial Intelligence, 1986.
Steven L Scott, Alexander W Blocker, Fernando V Bonassi, Hugh A Chipman, Edward I George, and
Robert E McCulloch. Bayes and big data: The consensus Monte Carlo algorithm. International
Journal of Management Science and Engineering Management, 11(2):78–88, 2016.
Ari Seff, Alex Beatson, Daniel Suo, and Han Liu. Continual learning in generative adversarial nets.
arXiv:1705.08395, 2017.
Rishit Sheth and Roni Khardon. A fixed-point operator for inference in variational Bayesian latent
Gaussian models. In Arthur Gretton and Christian C. Robert, editors, International Conference on
Artificial Intelligence and Statistics, volume 51 of Proceedings of Machine Learning Research, pages
761–769, Cadiz, Spain, 09–11 May 2016a. PMLR.
Rishit Sheth and Roni Khardon. Monte Carlo structured SVI for non-conjugate models. arXiv preprint
arXiv:1612.03957, 2016b.
Rishit Sheth, Yuyang Wang, and Roni Khardon. Sparse variational inference for generalized GP models.
In Francis Bach and David Blei, editors, International Conference on Machine Learning, volume 37,
pages 1302–1311, 2015.
Hanul Shin, Jung Kwon Lee, Jaehong Kim, and Jiwon Kim. Continual learning with deep generative
replay. In Advances in Neural Information Processing Systems, pages 2990–2999, 2017.
Sharad Singhal and Lance Wu. Training multilayer perceptrons with the extended Kalman algorithm.
In Advances in Neural Information Processing Systems, 1989.
Alex J. Smola, S.V.N. Vishwanathan, and Eleazar Eskin. Laplace propagation. In Advances in Neural
Information Processing Systems, 2004.
Erik B Sudderth, Alexander T Ihler, Michael Isard, William T Freeman, and Alan S Willsky. Nonpara-
metric belief propagation. Communications of the ACM, 53(10):95–103, 2010.
Richard S. Sutton and Steven D. Whitehead. Online learning with random representations. In
International Conference on Machine Learning, 1993.
The Royal Society. Machine learning: The power and promise of computers that learn by example.
Technical report, The Royal Society, 2017.
Lucas Theis and Matthew D Hoffman. A trust-region method for stochastic variational inference with
applications to streaming data. In International Conference on Machine Learning, 2015.
Tijmen Tieleman and Geoffrey E Hinton. Lecture 6.5—RmsProp: Divide the gradient by a running
average of its recent magnitude. COURSERA: Neural Networks for Machine Learning, 2012.
Dustin Tran, Matthew D. Hoffman, Rif A. Saurous, Eugene Brevdo, Kevin Murphy, and David M. Blei.
Deep probabilistic programming. In International Conference on Learning Representations, 2017.
Richard E. Turner and Maneesh Sahani. Two problems with variational expectation maximisation for
time-series models. In D. Barber, T. Cemgil, and S. Chiappa, editors, Bayesian Time series models,
chapter 5, pages 109–130. Cambridge University Press, 2011.
Matt P. Wand. Fully simplified multivariate Normal updates in non-conjugate variational message
passing. Journal of Machine Learning Research, 15:1351–1369, 2014.
Bo Wang and DM Titterington. Lack of consistency of mean field and variational Bayes approximations
for state space models. Neural Processing Letters, 20(3):151–170, 2004.
Xiangyu Wang and David B Dunson. Parallelizing MCMC via Weierstrass sampler. arXiv preprint
arXiv:1312.4605, 2013.
John Winn and Tom Minka. Probabilistic programming with Infer.NET. Machine Learning Summer
School, 2009.
John Winn, Christopher M. Bishop, and Tommi Jaakkola. Variational message passing. Journal of
Machine Learning Research, 6:661–694, 2005.
Friedemann Zenke, Ben Poole, and Surya Ganguli. Continual learning through synaptic intelligence. In
International Conference on Machine Learning, pages 3987–3995, 2017.
Chen Zeno, Itay Golan, Elad Hoffer, and Daniel Soudry. Bayesian gradient descent: Online variational
Bayes learning with increased robustness to catastrophic forgetting and weight pruning, 2018.
Sixin Zhang, Anna E Choromanska, and Yann LeCun. Deep learning with elastic averaging SGD. In
Advances in Neural Information Processing Systems, pages 685–693, 2015.
Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. Federated learning
with non-IID data, 2018.
A Proofs
A.1 Relating Local KL minimization to Local Free-energy Minimization: Proof
of Property 1
Substituting the tilted distribution $\hat{p}^{(i)}(\theta)$ into the KL divergence yields,
$$\mathrm{KL}\big(q(\theta)\,\|\,\hat{p}^{(i)}(\theta)\big) = \int \mathrm{d}\theta\, q(\theta)\log\frac{q(\theta)\hat{Z}_{b_i}t_{b_i}^{(i-1)}(\theta)}{q^{(i-1)}(\theta)p(\mathbf{y}_{b_i}|\theta)} = \log\hat{Z}_{b_i} - \int \mathrm{d}\theta\, q(\theta)\log\frac{q^{(i-1)}(\theta)p(\mathbf{y}_{b_i}|\theta)}{q(\theta)t_{b_i}^{(i-1)}(\theta)}.$$

$$p(\theta)\prod_m t_m^{(i)}(\theta) = p(\theta)t_{b_i}^{(i)}(\theta)\prod_{m\neq b_i} t_m^{(i)}(\theta) = p(\theta)\frac{q^{(i)}(\theta)}{q^{(i-1)}(\theta)}t_{b_i}^{(i-1)}(\theta)\prod_{m\neq b_i} t_m^{(i-1)}(\theta) = \frac{q^{(i)}(\theta)}{q^{(i-1)}(\theta)}\,p(\theta)\prod_m t_m^{(i-1)}(\theta) = q^{(i)}(\theta).$$
(b) Let $\eta_q$ and $\eta_q^*$ be the variational parameters of $q(\theta)$ and $q^*(\theta)$ respectively. We can write $q(\theta) = q(\theta;\eta_q)$
and $q^*(\theta) = q(\theta;\eta_q^*) = p(\theta)\prod_m t_m(\theta;\eta_q^*)$. First, we show that at convergence, the derivative of the
global free-energy equals the sum of the derivatives of the local free-energies. The derivative of the
local free-energy $\mathcal{F}_m(q(\theta))$ w.r.t. $\eta_q$ is:
$$\frac{d\mathcal{F}_m(q(\theta))}{d\eta_q}\bigg|_{\eta_q=\eta_q^*} = \frac{d}{d\eta_q}\int \mathrm{d}\theta\, q(\theta;\eta_q)\log\frac{p(\mathbf{y}_m|\theta)}{t_m(\theta;\eta_q^*)}\bigg|_{\eta_q=\eta_q^*}.$$
Summing both sides over all $m$,
$$\sum_m \frac{d\mathcal{F}_m(q(\theta))}{d\eta_q}\bigg|_{\eta_q=\eta_q^*} = \frac{d}{d\eta_q}\int \mathrm{d}\theta\, q(\theta;\eta_q)\log\frac{\prod_m p(\mathbf{y}_m|\theta)}{\prod_m t_m(\theta;\eta_q^*)}\bigg|_{\eta_q=\eta_q^*} = \frac{d}{d\eta_q}\int \mathrm{d}\theta\, q(\theta;\eta_q)\log\frac{p(\theta)\prod_m p(\mathbf{y}_m|\theta)}{q(\theta;\eta_q^*)}\bigg|_{\eta_q=\eta_q^*}.$$
Hence,
$$\frac{d\mathcal{F}(q(\theta))}{d\eta_q}\bigg|_{\eta_q=\eta_q^*} = \int \mathrm{d}\theta\,\frac{dq(\theta;\eta_q)}{d\eta_q}\log\frac{p(\theta)\prod_m p(\mathbf{y}_m|\theta)}{q(\theta;\eta_q^*)}\bigg|_{\eta_q=\eta_q^*} = \sum_m \frac{d\mathcal{F}_m(q(\theta))}{d\eta_q}\bigg|_{\eta_q=\eta_q^*}.$$
For all $m$, since $q^*(\theta) = \mathrm{argmax}_{q(\theta)\in\mathcal{Q}}\,\mathcal{F}_m(q(\theta))$, we have $\frac{d\mathcal{F}_m(q(\theta))}{d\eta_q}\big|_{\eta_q=\eta_q^*} = 0$, which implies:
$$\sum_m\frac{d^2\mathcal{F}_m(q(\theta))}{d\eta_q d\eta_q^\top}\bigg|_{\eta_q=\eta_q^*} = \frac{d^2}{d\eta_q d\eta_q^\top}\int \mathrm{d}\theta\, q(\theta;\eta_q)\log\frac{p(\theta)\prod_m p(\mathbf{y}_m|\theta)}{q(\theta;\eta_q^*)}\bigg|_{\eta_q=\eta_q^*}.$$
Now consider the Hessian of the global free-energy $\mathcal{F}(q(\theta))$:
$$\frac{d^2\mathcal{F}(q(\theta))}{d\eta_q d\eta_q^\top} = \frac{d^2}{d\eta_q d\eta_q^\top}\int \mathrm{d}\theta\, q(\theta;\eta_q)\log\frac{p(\theta)\prod_m p(\mathbf{y}_m|\theta)}{q(\theta;\eta_q)} = \int \mathrm{d}\theta\,\frac{d^2q(\theta;\eta_q)}{d\eta_q d\eta_q^\top}\log\frac{p(\theta)\prod_m p(\mathbf{y}_m|\theta)}{q(\theta;\eta_q)} - \underbrace{\frac{d}{d\eta_q^\top}\int \mathrm{d}\theta\,\frac{dq(\theta;\eta_q)}{d\eta_q}}_{=\,0}.$$
Hence, at $\eta_q = \eta_q^*$ the Hessian of the global free-energy equals the sum of the Hessians of the local
free-energies. For all $m$, since $q^*(\theta) = \mathrm{argmax}_{q(\theta)\in\mathcal{Q}}\,\mathcal{F}_m(q(\theta))$, the Hessian $\frac{d^2\mathcal{F}_m(q(\theta))}{d\eta_q d\eta_q^\top}\big|_{\eta_q=\eta_q^*}$ is negative definite, which
implies that the Hessian $\frac{d^2\mathcal{F}(q(\theta))}{d\eta_q d\eta_q^\top}\big|_{\eta_q=\eta_q^*}$ of the global free-energy is also negative definite. Therefore,
$q^*(\theta)$ is a maximum of $\mathcal{F}(q(\theta))$.
where $\eta_m$ are the natural parameters and $T(\theta)$ are the augmented sufficient statistics. We also assume that
the prior $p(\theta)$ and the variational distributions $q^{(i)}(\theta)$ are normalized exponential family distributions.
To derive the fixed-point updates for the local variational inference algorithm, we consider maximizing
the local variational free-energy $\mathcal{F}^{(i)}(q(\theta))$ in (2). Assume that at iteration $i$, the variational distribution
$q^{(i-1)}(\theta)$ obtained from the previous iteration has the form $q^{(i-1)}(\theta) \propto \exp\{\eta_{\mathrm{prev}}^\top T(\theta)\}$,
and the target distribution $q(\theta)$ that we optimize has the form $q(\theta) \propto \exp\{\eta_q^\top T(\theta)\}$.
where $C := \frac{d^2A(\eta_q)}{d\eta_q d\eta_q^\top} = \mathrm{cov}_{q(\theta)}[T(\theta)T^\top(\theta)]$ is the Fisher information. Note that from Property 2,
$\eta_q = \eta_0 + \sum_{m\neq b_i}\eta_m^{(i-1)} + \eta_{b_i}^{(i)}$ and $\eta_{\mathrm{prev}} = \eta_0 + \sum_m \eta_m^{(i-1)}$, where $\eta_0$ are the natural parameters of the
prior $p(\theta)$. Hence, from (25), a fixed-point update for Algorithm 1 only needs to update $\eta_{b_i}^{(i)}$ locally
using:
$$\eta_{b_i}^{(i)} = C^{-1}\frac{d}{d\eta_q}\mathbb{E}_q(\log p(\mathbf{y}_{b_i}|\theta)). \qquad (26)$$
The Fisher information can be written as $C = \frac{d\mu_q}{d\eta_q}$, where $\mu_q = \mathbb{E}_q(T(\theta))$ is the mean parameter of $q(\theta)$.
This leads to a cancellation of the Fisher information,
$$\eta_{b_i}^{(i)} = \frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_{b_i}|\theta)). \qquad (27)$$
Evaluating the expectation at the previous iterate $q^{(i-1)}$ gives the update
$$\eta_{b_i}^{(i)} = \frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_{b_i}|\theta)). \qquad (28)$$
Here we have explicitly denoted the dependence of the approximate posterior on the iteration number $i$,
as the dynamics of this are the focus. When $M = N$ and a parallel fixed-point update of all the natural
parameters, $\{\eta_n^{(i)}\}_{n=1}^N$, is performed, the natural parameters of $q(\theta)$ are updated as follows,
$$\eta_q^{(i)} = \eta_0 + \sum_{m=1}^{M}\eta_m^{(i)} = \eta_0 + \sum_{m=1}^{M}\frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_m|\theta)) = \eta_0 + \sum_{n=1}^{N}\frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(y_n|\theta)). \qquad (29)$$
Here, in the last line, we have used the fact that the data are independent conditioned on $\theta$. Now
consider applying the fixed-point updates in the batch case ($M = 1$),
$$\eta_q^{(i)} = \eta_0 + \eta^{(i)} = \eta_0 + \frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}|\theta)) = \eta_0 + \sum_{n=1}^{N}\frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(y_n|\theta)). \qquad (30)$$
Therefore, the updates for $q(\theta)$ are identical in the two cases. A naïve implementation of local VI requires
$M$ sets of natural parameters to be maintained, whereas parallel updating renders this unnecessary
while being equivalent to fixed-point global VI.
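As a concrete illustration of this equivalence, the following toy check (our addition, not from the paper) verifies that the parallel local updates with M = N and the batch update with M = 1 produce identical natural parameters for a conjugate Gaussian model, where the fixed-point updates are available in closed form.

    # Toy check: parallel local fixed-point updates (M = N) match the batch
    # update (M = 1) for y_n ~ N(theta, s2) with prior p(theta) = N(0, 1).
    import numpy as np

    rng = np.random.default_rng(0)
    y, s2 = rng.normal(1.0, 1.0, size=20), 1.0
    eta0 = np.array([0.0, -0.5])            # natural parameters of N(0, 1)

    # d/dmu E_q[log p(y_n | theta)] with mean parameters mu = (E[theta], E[theta^2]);
    # for this conjugate model the gradient does not depend on q.
    local = np.stack([np.array([yn / s2, -0.5 / s2]) for yn in y])

    eta_parallel = eta0 + local.sum(axis=0)                            # M = N
    eta_batch = eta0 + np.array([y.sum() / s2, -0.5 * len(y) / s2])    # M = 1
    assert np.allclose(eta_parallel, eta_batch)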
Hoffman et al. [2013] and Sheth and Khardon [2016b] employ a stochastic mini-batch approximation of
$\mathbb{E}_q(\log p(\mathbf{y}|\theta)) = \sum_n \mathbb{E}_{q(\theta)}(\log p(y_n|\theta)) = N\,\mathbb{E}_{p_{\mathrm{data}}(y),q(\theta)}(\log p(y_n|\theta))$ obtained by sub-sampling the data
distribution $\mathbf{y}_l \overset{\mathrm{iid}}{\sim} p_{\mathrm{data}}(y)$, where $p_{\mathrm{data}}(y) = \frac{1}{N}\sum_{n=1}^{N}\delta(y - y_n)$. This approximation yields,
$$\eta_q^{(i)} = (1-\rho)\eta_q^{(i-1)} + \rho\left(\eta_0 + L\,\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_l|\theta))\right), \qquad (32)$$
where $\mathbf{y}_l$ is a mini-batch containing $N/L$ data points. Li et al. [2015] show that their stochastic power
EP algorithm recovers precisely these same updates when $\alpha \to 0$ (this is related to Property 6). The
relation to EP-like algorithms can be made more explicit by writing (32) as
$$\eta_q^{(i)} = \eta_q^{(i-1)} + \rho'\left(\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_l|\theta)) - \eta_{\mathrm{like}}^{(i-1)}/L\right), \qquad (33)$$
where $\rho' = \rho L$ is a rescaled learning rate and $\eta_{\mathrm{like}}^{(i-1)} = \eta_q^{(i-1)} - \eta_0$ is the portion of the approximate
posterior natural parameters that approximates the likelihoods. As such, $\eta_{\mathrm{like}}^{(i-1)}/L$ is the contribution a
mini-batch likelihood makes on average to the posterior. Interpreting the update in these terms,
the new approximate posterior equals the old approximate posterior with $\rho'$ of the average mini-batch
likelihood approximation removed and $\rho'$ of the approximation from the current mini-batch,
$\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_l|\theta))$, added in its place.
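In code, one step of the stochastic global update (33) can be sketched as follows; the nat_grad_term helper, standing for $\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_l|\theta))$, is a hypothetical placeholder rather than part of the paper.

    def stochastic_global_step(eta_q, eta_0, batch, nat_grad_term, rho_prime, L):
        # Update (33): delete rho' of the average mini-batch likelihood
        # approximation and add rho' of the current mini-batch contribution.
        eta_like = eta_q - eta_0            # likelihood part of the posterior
        delta = nat_grad_term(eta_q, batch) - eta_like / L
        return eta_q + rho_prime * delta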
Khan and Lin [2018] take a different approach, employing damped, simplified fixed-point updates
for local VI with $M = N$, and using an update schedule that selects a mini-batch at random and then
updates the local natural parameters for these data points in parallel. That is, for all data points in the
mini-batch they apply
$$\eta_n^{(i)} = (1-\rho)\eta_n^{(i-1)} + \rho\frac{d}{d\mu_q}\mathbb{E}_q(\log p(y_n|\theta)). \qquad (34)$$
This local update incurs a memory overhead that is $N$ times larger due to the need to maintain $N$ sets of
local parameters rather than just one. However, if the mini-batch partition is fixed across epochs, it is
sufficient to maintain $M$ sets of natural parameters instead, one for each mini-batch,
$$\eta_m^{(i)} = (1-\rho)\eta_m^{(i-1)} + \rho\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_m|\theta)). \qquad (35)$$
Interestingly, both of these local updates ((34) and (35)) result in a subtly different update to $q$ from the
stochastic global update above (33),
$$\eta_q^{(i)} = \eta_q^{(i-1)} + \rho\left(\frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_m|\theta)) - \eta_m^{(i-1)}\right). \qquad (36)$$
Here the deletion step is explicitly revealed again, but now it involves removing the natural
parameters of the $m$-th approximate likelihood $\eta_m^{(i-1)}$, rather than the average mini-batch likelihood
approximation $\eta_{\mathrm{like}}^{(i-1)}/L$.
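For comparison, a companion sketch of the local update (35) and the induced change to $q$ in (36): one set of approximate-likelihood natural parameters is stored per mini-batch partition, and the deletion step removes exactly the $m$-th contribution (nat_grad_term is the same hypothetical helper as above).

    def stochastic_local_step(eta_q, eta_m, batch_m, nat_grad_term, rho):
        # Update (35) for the m-th approximate likelihood; the induced change
        # to the global posterior is exactly update (36).
        eta_m_new = (1 - rho) * eta_m + rho * nat_grad_term(eta_q, batch_m)
        eta_q = eta_q + (eta_m_new - eta_m)   # delete old eta_m, insert new one
        return eta_q, eta_m_new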
A summary of work employing these two types of stochastic approximation is provided in figure 8.
Each variety of update has its own pros and cons (compare (33) to (36)). The stochastic global update
(33) does not support general online learning or distributed asynchronous updates, but it is more
memory efficient and faster to converge in the batch setting.
Consider applying both methods to online learning, where $\mathbf{y}_l$ or $\mathbf{y}_m$ correspond to the new data seen
at each stage and $\rho'$ and $\rho$ are user-determined learning rates for the two algorithms. General online
learning poses two challenges for the stochastic global update (33).
[Figure 8 diagram: algorithms arranged by the granularity M of the approximation (from global, M = 1, to fully local, M = N), whether a mini-batch approximation is used, and whether a stochastic schedule is used. Key: [1] Beal 2003; [2] Winn et al. 2005; [3] Knowles & Minka 2011; [4] Hoffman et al. 2013; [5] Hensman et al. 2013; [6] Sheth & Khardon 2016b; [7] Salimans & Knowles 2013; [8] Li et al. 2015; [9] Khan & Lin 2018. Notes: global + stochastic schedule = global; fully local + mini-batch = fully local.]
Figure 8: Forms of stochastic approximation in the local VI framework and the relationship to previous
work. The granularity of the approximation is controlled through M. A mini-batch approximation may
be used inside each local variational free-energy. A stochastic schedule can be used to randomize the
order in which groups of data points are visited. All algorithms have the same fixed points (the
mini-batch approximation learning rate ρ has to obey the Robbins-Monro conditions).
First, the data are typically not iid (due to covariate or data set shift over time). Second, general online
learning does not allow all old data points to be revisited, demanding incremental updates instead. This
means that when a new batch of data is received, we must iterate on just these data to refine the
approximate posterior, before moving on and potentially never returning. Iterating (33) on the new data
is possible, but it would have disastrous consequences, as it again breaks the iid mini-batch assumption
and would result in $q$ fitting the new data and forgetting the old, previously seen data. A single
update could be made, but this will normally mean that the approach is data-inefficient and slow to
learn. Iterating the local updates (34) or (35), on the other hand, works nicely, as these naturally
incorporate a deletion step that removes just the contribution from the current mini-batch and can
therefore be iterated to our heart's content.
Similar arguments can be made about the distributed setting too. Stochastic global updates could
be used in the distributed setting, with each worker querying a server for the current natural parameters
$\eta_q$, computing $\Delta\eta_l = \frac{d}{d\mu_q}\mathbb{E}_q(\log p(\mathbf{y}_l|\theta)) - \eta_{\mathrm{like}}/L$, and communicating this to the server. The server's
role is to update the global parameters $\eta_q^{(\mathrm{new})} = \eta_q^{(\mathrm{old})} + \Delta\eta_l$ and to send these new parameters to
the workers. The difficulty is that this setup must obey the iid assumption, so i) the data must be
distributed across the workers in an iid way, and ii) the updates must be returned by each worker with
the same regularity. In contrast, the stochastic local updates can be used in a very similar way without
these restrictions.
The stochastic global update (32) does have two important advantages over the local updates (36).
First, the memory footprint is $L$ times smaller, requiring only a single set of natural parameters to be
maintained rather than $M$ of them. Second, it can be faster to converge when the mini-batches are iid.
Contrast what happens when new data are seen for the first time in the two approaches. For simplicity,
assume $\rho' = \rho = 1$. In the second approach, $\eta_m^{(i-1)} = 0$ as the data have not been seen before, but
the first approach effectively uses $\eta_m^{(i-1)} = \eta_{\mathrm{like}}^{(i-1)}/L$. That is, the first approach effectively estimates
the approximate likelihood for new data based on those for previously seen data. This is a sensible
approximation for homogeneous mini-batches. A consequence of this is that the learning rate $\rho'$ can be
much larger than $\rho$ (potentially greater than unity), resulting in faster convergence of the approximate
posterior. It would be interesting to consider modifications of the local updates (36) that estimate the
$m$-th approximate likelihood based on information from all other data partitions. For example, in the
first pass through the data, the approximate likelihoods for unprocessed mini-batches could be updated
to be equal to the last approximate likelihood or to the geometric average of previous approximate
likelihoods. Alternatively, ideas from inference networks could be employed for this purpose.
A.7 The relationship between natural gradient, mirror-descent, trust-region and
proximal methods
Each step of gradient ascent of parameters $\eta$ on a cost $\mathcal{C}$ can be interpreted as the result of an
optimization problem derived from linearizing the cost function around the old parameter estimate,
$\mathcal{C}(\eta) \approx \mathcal{C}(\eta^{(i-1)}) + \nabla_\eta\mathcal{C}(\eta^{(i-1)})(\eta - \eta^{(i-1)})$ where $\nabla_\eta\mathcal{C}(\eta^{(i-1)}) = \frac{d\mathcal{C}(\eta)}{d\eta}\big|_{\eta=\eta^{(i-1)}}$, and adding a soft constraint
on the norm of the parameters:
$$\eta^{(i)} = \eta^{(i-1)} + \rho\,\nabla_\eta\mathcal{C}(\eta^{(i-1)}) \quad\Longleftrightarrow\quad \eta^{(i)} = \underset{\eta}{\mathrm{argmax}}\;\nabla_\eta\mathcal{C}(\eta^{(i-1)})\,\eta - \frac{1}{2\rho}\|\eta - \eta^{(i-1)}\|_2^2.$$
Here the terms in the linearized cost that do not depend on $\eta$ have been dropped, as they do not affect
the solution of the optimization problem. The linearization of the cost ensures that there is an analytic
solution to the optimization, and the soft constraint ensures that we do not move too far from the
previous setting of the parameters into a region where the linearization is inaccurate.
This reframing of gradient ascent reveals that it makes a Euclidean assumption about the
geometry of the parameter space, and it suggests generalizations of the procedure that are suitable for
different geometries. In our case, the parameters are natural parameters of a distribution, and measures
of proximity that employ the KL divergence are natural.
The main result of this section is that the following optimization problems:
$$\text{KL proximal method:}\quad \eta_q^{(i)} = \underset{\eta_q}{\mathrm{argmax}}\;\nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\eta_q - \tfrac{1}{\rho}\mathrm{KL}\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) \qquad (37)$$
$$\text{KL trust region:}\quad \eta_q^{(i)} = \underset{\eta_q}{\mathrm{argmax}}\;\nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\eta_q \;\;\text{s.t.}\;\; \mathrm{KL}\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) \le \gamma \qquad (38)$$
$$\mathrm{KL}^s\text{ trust region:}\quad \eta_q^{(i)} = \underset{\eta_q}{\mathrm{argmax}}\;\nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\eta_q \;\;\text{s.t.}\;\; \mathrm{KL}^s\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) \le \gamma \qquad (39)$$
$$\text{mirror descent:}\quad \mu_q^{(i)} = \underset{\mu_q}{\mathrm{argmax}}\;\nabla_\mu\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\mu_q - \tfrac{1}{\rho}\mathrm{KL}\big(q^{(i-1)}(\theta)\,\|\,q(\theta)\big) \qquad (40)$$
all yield the same updates as the damped fixed-point equations / natural gradient ascent:
$$\eta_{b_i}^{(i)} = (1-\rho)\eta_{b_i}^{(i-1)} + \rho\frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_{b_i}|\theta)) \qquad (41)$$
$$\phantom{\eta_{b_i}^{(i)}} = (1-\rho)\eta_{b_i}^{(i-1)} + \rho\left[\frac{d\mu_{q^{(i-1)}}}{d\eta_{q^{(i-1)}}}\right]^{-1}\frac{d}{d\eta_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_{b_i}|\theta)). \qquad (42)$$
In the first three cases (37)-(39), this equivalence only holds exactly in the general case if the parameter
changes $\Delta\eta_q^{(i)} = \eta_q^{(i)} - \eta_q^{(i-1)}$ are small.
Here, the KL proximal method is the straightforward generalization of the gradient ascent example
that replaces the Euclidean norm by the exclusive KL divergence. The KL trust region method
uses a hard constraint on the same KL instead, but rewriting this as a Lagrangian recovers the KL
proximal method with $1/\rho$ being the Lagrange multiplier. The $\mathrm{KL}^s$ trust region method, often used
to justify natural gradient ascent, employs the symmetrized KL divergence instead of the exclusive
KL. The symmetrized KL is the average of the exclusive and inclusive KLs, $\mathrm{KL}^s\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) =
\tfrac{1}{2}\mathrm{KL}\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) + \tfrac{1}{2}\mathrm{KL}\big(q^{(i-1)}(\theta)\,\|\,q(\theta)\big)$. Mirror descent, in its most general form, uses a
Bregman divergence to control the extent to which the parameters change, rather than a KL divergence.
However, when applied to exponential families, the Bregman divergence becomes the inclusive KL
divergence, yielding the form above [Raskutti and Mukherjee, 2015, Khan et al., 2016, Khan and
Lin, 2018]. Note that this last method operates in the mean parameter space, and the equivalence is
attained by mapping the mean parameter updates back to the natural parameters. Mirror descent has
the advantage of not relying on the small-parameter-change assumption to recover natural gradient
ascent. Having explained the rationale behind these approaches, we now sketch how they yield the
fixed-point updates.
The equivalence of the KL proximal method can be shown by differentiating the cost w.r.t. $\eta_q$ and
substituting in the following expressions:
$$\nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta)) = \frac{d}{d\eta_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_{b_i}|\theta)) - \frac{d\mu_{q^{(i-1)}}}{d\eta_{q^{(i-1)}}}\,\eta_{b_i}^{(i-1)},$$
$$\frac{d\,\mathrm{KL}\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big)}{d\eta_q} = \frac{d\mu_q}{d\eta_q}\big(\eta_q - \eta_q^{(i-1)}\big) \approx \frac{d\mu_{q^{(i-1)}}}{d\eta_{q^{(i-1)}}}\big(\eta_q - \eta_q^{(i-1)}\big).$$
In the second line above, the approximation results from the assumption of a small parameter change
$\Delta\eta_q^{(i)}$ (or, alternatively, local constancy of the Fisher information). Equating the derivatives to zero and
rearranging recovers the fixed-point equations.
The equivalence of the KL trust region method is now simple to show, as the associated Lagrangian,
$\mathcal{L}(\eta_q) = \nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\eta_q - \tfrac{1}{\rho}\big(\mathrm{KL}\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) - \gamma\big)$, is the proximal method up to an additive
constant.
The $\mathrm{KL}^s$ trust region method can also be rewritten as a Lagrangian, $\mathcal{L}(\eta_q) = \nabla_\eta\mathcal{F}^{(i)}(q^{(i-1)}(\theta))\,\eta_q -
\tfrac{1}{\rho}\big(\mathrm{KL}^s\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) - \gamma\big)$. For small changes in the approximate posterior natural parameters
$\Delta\eta_q^{(i)} = \eta_q^{(i)} - \eta_q^{(i-1)}$, the symmetrized KL can be approximated using a second-order Taylor expansion,
$$\mathrm{KL}^s\big(q(\theta)\,\|\,q^{(i-1)}(\theta)\big) \approx \frac{1}{2}\big(\eta_q - \eta_q^{(i-1)}\big)^\top\frac{d\mu_{q^{(i-1)}}}{d\eta_{q^{(i-1)}}}\big(\eta_q - \eta_q^{(i-1)}\big). \qquad (43)$$
This is the same form as the exclusive KL takes, the inclusive and exclusive KL divergences being
locally identical around their optima. Taking derivatives of the Lagrangian and setting them to zero
recovers the fixed-point equations again.
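This local equivalence is easy to check numerically. The toy script below (our addition) compares the symmetrized KL between two nearby univariate Gaussians with the quadratic form (43), using a finite-difference estimate of the Fisher information $d\mu/d\eta$.

    import numpy as np

    def kl(m1, v1, m2, v2):                 # KL(N(m1, v1) || N(m2, v2))
        return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

    def nat_to_moment(eta):                 # (eta1, eta2) -> (mean, variance)
        v = -0.5 / eta[1]
        return eta[0] * v, v

    def mu(eta):                            # mean parameters (E[x], E[x^2])
        m, v = nat_to_moment(eta)
        return np.array([m, m ** 2 + v])

    eta_old = np.array([1.0, -0.5])
    eta = eta_old + np.array([1e-3, -2e-3])  # small parameter change
    (m1, v1), (m2, v2) = nat_to_moment(eta), nat_to_moment(eta_old)
    kls = 0.5 * (kl(m1, v1, m2, v2) + kl(m2, v2, m1, v1))

    eps = 1e-6                               # finite-difference Fisher C = dmu/deta
    C = np.stack([(mu(eta_old + eps * e) - mu(eta_old - eps * e)) / (2 * eps)
                  for e in np.eye(2)], axis=1)
    d = eta - eta_old
    assert np.isclose(kls, 0.5 * d @ C @ d, rtol=0.1)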
The mirror descent method can be shown to yield the fixed points by noting that
$$\nabla_\mu\mathcal{F}^{(i)}(q^{(i-1)}(\theta)) = \frac{d}{d\mu_{q^{(i-1)}}}\mathbb{E}_{q^{(i-1)}}(\log p(\mathbf{y}_{b_i}|\theta)) - \eta_{b_i}^{(i-1)},$$
$$\frac{d\,\mathrm{KL}\big(q^{(i-1)}(\theta)\,\|\,q(\theta)\big)}{d\mu_q} = \eta_q - \eta_q^{(i-1)}.$$
Differentiating the mirror descent objective and substituting these results in recovers the usual update. The
last result above can be found using convex duality. For a full derivation and more information on the
relationship between mirror descent and natural gradients, see Raskutti and Mukherjee [2015].
It is also possible to define optimization approaches analogous to the above that do not linearize the
free-energy term and instead perform potentially multiple updates of the nested non-linear optimization
problems [Theis and Hoffman, 2015, Khan et al., 2015, 2016].
2. $q_\alpha^{(i)}(\theta) = \mathrm{proj}(\hat{p}_\alpha^{(i)}(\theta))$ % moment match
3. $q^{(i)}(\theta) = \big(q^{(i-1)}(\theta)\big)^{1-1/\alpha}\big(q_\alpha^{(i)}(\theta)\big)^{1/\alpha}$ % update posterior
4. $t_{b_i}^{(i)}(\theta) = \frac{q^{(i)}(\theta)}{q^{(i-1)}(\theta)}\,t_{b_i}^{(i-1)}(\theta)$ % update approximate likelihood
Plugging this into (47), we obtain:
$$\log\mathrm{proj}(\hat{p}_\alpha^{(i)}(\theta)) = \log q^{(i-1)}(\theta) + \alpha\,T(\theta)^\top\frac{d}{d\mu_{q^{(i-1)}}}\int q^{(i-1)}(\theta)\log\frac{p(\mathbf{y}_{b_i}|\theta)}{t_{b_i}^{(i-1)}(\theta)}\,\mathrm{d}\theta + F(\alpha). \qquad (55)$$
Note that:
$$\frac{d}{d\mu_{q^{(i-1)}}}\int T(\theta)^\top q^{(i-1)}(\theta)\log t_{b_i}^{(i-1)}(\theta)\,\mathrm{d}\theta = \log t_{b_i}^{(i-1)}(\theta)\,\frac{d}{d\mu_{q^{(i-1)}}}\int T(\theta)^\top q^{(i-1)}(\theta)\,\mathrm{d}\theta \qquad (60)$$
$$= \log t_{b_i}^{(i-1)}(\theta)\,\frac{d\mu_{q^{(i-1)}}}{d\mu_{q^{(i-1)}}} \qquad (61)$$
$$= \log t_{b_i}^{(i-1)}(\theta). \qquad (62)$$
$$\frac{d}{d\mu_{\bar{q}}}\int \bar{q}(\theta)\log p(\mathbf{y}_{b_i}|\theta)\,\mathrm{d}\theta = \frac{d\mu_q}{d\mu_{\bar{q}}}\frac{d}{d\mu_q}\int \frac{1}{Z_q}q(\theta)\log p(\mathbf{y}_{b_i}|\theta)\,\mathrm{d}\theta \qquad (66)$$
$$= \frac{1}{Z_q}(Z_q I)\frac{d}{d\mu_q}\int q(\theta)\log p(\mathbf{y}_{b_i}|\theta)\,\mathrm{d}\theta \qquad (67)$$
$$= \frac{d}{d\mu_q}\int q(\theta)\log p(\mathbf{y}_{b_i}|\theta)\,\mathrm{d}\theta. \qquad (68)$$
B Gradients of the free-energy with respect to hyperparameters
In this section, we derive the gradients of the global free-energy w.r.t. the hyperparameters when the
local VI procedure has converged to the optimal approximate posterior (whether through analytic, off-the-shelf,
or fixed-point optimization). We provide two derivations: first, the standard one, which is specific
to approximate posteriors in the exponential family; second, a derivation that applies to
general $q(\theta)$ and provides more insight.
Note that,
Hence,
$$\mathcal{F}_1 = A(\eta_q) - A(\eta_0) - \sum_m \eta_m^\top\,\mathbb{E}_{q(\theta)}[T(\theta)] = A(\eta_q) - A(\eta_0) - \sum_m \eta_m^\top\,\frac{dA(\eta_q)}{d\eta_q}.$$
Differentiating $\mathcal{F}_1$ and $\mathcal{F}_2$ w.r.t. a hyperparameter $\epsilon$ of the model, noting that the natural parameters of
the local factors and the global variational approximation both depend on $\epsilon$, gives,
$$\frac{d\mathcal{F}_1}{d\epsilon} = \frac{dA(\eta_q)}{d\eta_q}^\top\frac{d\eta_q}{d\epsilon} - \frac{dA(\eta_0)}{d\eta_0}^\top\frac{d\eta_0}{d\epsilon} - \sum_m\left[\frac{dA(\eta_q)}{d\eta_q}^\top\frac{d\eta_m}{d\epsilon} + \eta_m^\top\frac{d^2A(\eta_q)}{d\eta_q d\eta_q^\top}\frac{d\eta_q}{d\epsilon}\right],$$
$$\frac{d\mathcal{F}_2}{d\epsilon} = \sum_m\left[\frac{\partial\mathcal{F}_{2m}}{\partial\epsilon} + \frac{\partial\mathcal{F}_{2m}}{\partial\eta_q}^\top\frac{d\eta_q}{d\epsilon}\right].$$
Note that $\eta_q = \eta_0 + \sum_m \eta_m$, leading to
$$-\sum_m\frac{d\eta_m}{d\epsilon} = -\frac{d\eta_q}{d\epsilon} + \frac{d\eta_0}{d\epsilon}. \qquad (69)$$
Here,
$$\frac{\partial\mathcal{F}_{2m}}{\partial\eta_q} = \frac{\partial}{\partial\eta_q}\mathbb{E}_q(\log p(\mathbf{y}_m|\theta)), \qquad (70)$$
and, at convergence (see (31)),
$$\frac{\partial\mathcal{F}_{2m}}{\partial\eta_q} = \frac{d\mu_q}{d\eta_q}\,\eta_m = \frac{d^2A(\eta_q)}{d\eta_q d\eta_q^\top}\,\eta_m. \qquad (71)$$
Therefore,
$$\frac{d\mathcal{F}}{d\epsilon} = \sum_m\frac{\partial\mathcal{F}_{2m}}{\partial\epsilon} + (\mu_q - \mu_0)^\top\frac{d\eta_0}{d\epsilon},$$
where $\mu_q = \frac{dA(\eta_q)}{d\eta_q} = \mathbb{E}_{q(\theta)}[T(\theta)]$ are the mean parameters.
$$q(\theta;\psi) = \prod_{n=0}^{N} t_n(\theta;\psi). \qquad (74)$$
Here $\psi$ are the variational parameters (these may correspond to natural parameters). Again, the scale
of the approximate terms $t_n(\theta;\psi)$ will be set such that $q(\theta;\psi)$ is normalized. Note that this is a more
general form of approximate posterior that allows the prior to be approximated if it lies outside of the
variational family $\mathcal{Q}$. If the prior lies within the variational family, the local updates will automatically
set the corresponding factor equal to the prior, recovering the treatment in the rest of the paper and
meaning that the results presented here still hold.
The global free-energy depends on the model hyperparameters $\epsilon$ through the joint distribution
$p(\mathbf{y},\theta|\epsilon)$,
$$\mathcal{F}(\epsilon, q(\theta;\psi)) = \int \mathrm{d}\theta\, q(\theta;\psi)\log\frac{p(\mathbf{y},\theta|\epsilon)}{q(\theta;\psi)}. \qquad (75)$$
Now consider the optimal variational approximation for a fixed setting of the hyperparameters,
$$\psi^{\mathrm{opt}}(\epsilon) = \underset{\psi}{\mathrm{argmax}}\int \mathrm{d}\theta\, q(\theta;\psi)\log\frac{p(\mathbf{y},\theta|\epsilon)}{q(\theta;\psi)}. \qquad (76)$$
The collapsed variational bound can therefore be denoted $\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon)))$, and it is this that we will
optimize to find the hyperparameters. Before we do so, note that we have been careful to represent
the two distinct ways that the free-energy depends on the hyperparameters: i) through the log-joint's
dependence, and ii) through the optimal approximate posterior's implicit dependence via $\psi^{\mathrm{opt}}(\epsilon)$. In fact, we
can decouple these two contributions and consider evaluating the free-energy when the hyperparameters
differ, $\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon')))$, the collapsed bound being recovered when $\epsilon' = \epsilon$.
We are now in a position to compute derivatives of the collapsed free-energy, using the insight above
to split them into two terms,
$$\frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon))) = \frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon')))\bigg|_{\epsilon'=\epsilon} + \frac{d}{d\epsilon'}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon')))\bigg|_{\epsilon'=\epsilon}. \qquad (77)$$
We now consider these two terms. First, the dependence through the log-joint distribution,
$$\frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon')))\bigg|_{\epsilon'=\epsilon} = \int \mathrm{d}\theta\, q(\theta;\psi(\epsilon'))\frac{d}{d\epsilon}\log p(\mathbf{y},\theta|\epsilon)\bigg|_{\epsilon'=\epsilon}$$
$$= \sum_{m=1}^{M}\int \mathrm{d}\theta\, q(\theta;\psi(\epsilon'))\frac{d}{d\epsilon}\log p(\mathbf{y}_m|\theta,\epsilon)\bigg|_{\epsilon'=\epsilon} + \int \mathrm{d}\theta\, q(\theta;\psi(\epsilon'))\frac{d}{d\epsilon}\log p(\theta|\epsilon)\bigg|_{\epsilon'=\epsilon}$$
$$= \sum_{m=1}^{M}\mathbb{E}_{q(\theta)}\left[\frac{d}{d\epsilon}\log p(\mathbf{y}_m|\theta,\epsilon)\right] + \mathbb{E}_{q(\theta)}\left[\frac{d}{d\epsilon}\log p(\theta|\epsilon)\right]. \qquad (78)$$
Second, the dependence through the optimal approximate posterior's implicit dependence on $\epsilon$,
$$\frac{d}{d\epsilon'}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon')))\bigg|_{\epsilon'=\epsilon} = \frac{d\psi^{\mathrm{opt}}(\epsilon')}{d\epsilon'}\frac{d}{d\psi}\mathcal{F}(\epsilon, q(\theta;\psi))\bigg|_{\epsilon'=\epsilon,\,\psi=\psi^{\mathrm{opt}}(\epsilon')} = 0. \qquad (79)$$
Here we have substituted in the fact that we are at the collapsed bound, and so the derivative w.r.t. $\psi$
is zero.
So the term that arises from the dependence of the approximate posterior on the hyperparameters
(the second term) vanishes, meaning that the only contribution comes from the first term. This is precisely the
same term that would remain if we were to perform coordinate ascent (since then, when updating the
hyperparameters, the approximate posterior would be fixed):
$$\frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon))) = \sum_{m=1}^{M}\mathbb{E}_{q(\theta;\psi^{\mathrm{opt}}(\epsilon))}\left[\frac{d}{d\epsilon}\log p(\mathbf{y}_m|\theta,\epsilon)\right] + \mathbb{E}_{q(\theta;\psi^{\mathrm{opt}}(\epsilon))}\left[\frac{d}{d\epsilon}\log p(\theta|\epsilon)\right]. \qquad (80)$$
When the prior distribution is in the exponential family, the second term above becomes
$$\mathbb{E}_{q(\theta;\psi^{\mathrm{opt}}(\epsilon))}\left[\frac{d}{d\epsilon}\log p(\theta|\epsilon)\right] = (\mu_q - \mu_0)^\top\frac{d\eta_0}{d\epsilon}. \qquad (81)$$
This recovers the expression in the previous section, although here we have not assumed that the approximate
posterior is in the exponential family ($\mu_q$ and $\mu_0$ are the averages of the prior's sufficient statistics
under the approximate posterior and the prior respectively).
Figure 9 provides some intuition for these results. Note that in the case where the approximating
family includes the true posterior distribution, the collapsed bound is equal to the log-likelihood of the
hyperparameters. The result then shows that the gradient of the log-likelihood w.r.t. the hyperparameters
is equal to the gradient of the free-energy w.r.t. the hyperparameters with $q$ treated as fixed. This
gradient is often computed in the M-step of variational EM, where it is used for coordinate ascent, which
can be slow to converge. Instead, the gradient can be passed to an optimizer to perform direct
gradient-based optimization of the log-likelihood.
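The following toy script (our addition) illustrates this result for a conjugate model where the collapsed bound equals the log marginal likelihood exactly: the gradient of the free-energy with q held fixed at its optimum matches the gradient of the log marginal likelihood. The model, y ~ N(θ, exp(ε)) with prior θ ~ N(0, 1), is an assumption made for the sake of the example.

    import numpy as np

    y, eps0 = 1.3, 0.2                       # one datum; noise variance exp(eps)

    def posterior(eps):                      # exact posterior N(m, v)
        s2 = np.exp(eps)
        v = 1.0 / (1.0 + 1.0 / s2)
        return v * y / s2, v

    def free_energy(eps, m, v):              # E_q[log p(y, theta | eps)] + H[q]
        s2 = np.exp(eps)
        e_lik = -0.5 * np.log(2 * np.pi * s2) - ((y - m) ** 2 + v) / (2 * s2)
        e_prior = -0.5 * np.log(2 * np.pi) - (m ** 2 + v) / 2
        return e_lik + e_prior + 0.5 * np.log(2 * np.pi * np.e * v)

    def log_marginal(eps):                   # log N(y; 0, 1 + exp(eps))
        s2 = 1.0 + np.exp(eps)
        return -0.5 * np.log(2 * np.pi * s2) - y ** 2 / (2 * s2)

    m, v = posterior(eps0)                   # q fixed at its optimum for eps0
    h = 1e-6
    g_fixed_q = (free_energy(eps0 + h, m, v) - free_energy(eps0 - h, m, v)) / (2 * h)
    g_marginal = (log_marginal(eps0 + h) - log_marginal(eps0 - h)) / (2 * h)
    assert np.isclose(g_fixed_q, g_marginal, atol=1e-6)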
Figure 9: Contours of the free-energy $\mathcal{F}(\epsilon, q(\theta;\psi))$ are shown in green as a function of the hyperparameters
$\epsilon$ and the variational parameters $\psi$ of the approximate posterior. The collapsed bound
$\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon)))$ is shown in blue. The gradients of the free-energy with respect to the variational
parameters are zero along the collapsed bound, $\frac{d}{d\psi}\mathcal{F}(\epsilon, q(\theta;\psi))|_{\psi=\psi^{\mathrm{opt}}} = 0$, by definition. This means
that the gradients of the collapsed free-energy as a function of the hyperparameters are equal to those
of the free-energy itself, $\frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi^{\mathrm{opt}}(\epsilon))) = \frac{d}{d\epsilon}\mathcal{F}(\epsilon, q(\theta;\psi))|_{\psi=\psi^{\mathrm{opt}}(\epsilon)}$.
where $R$ is the number of batches considered and $t_r(\mathbf{u})$ is the approximate contribution of the $r$-th
batch to the posterior. This is the standard setup considered for sparse GPs in the literature, see
e.g. Hensman et al. [2013], Bui et al. [2017b]. We next detail the specifics for the streaming setting
[Bui et al., 2017a], when we allow the pseudo-points to move and adjust the hyperparameters as new
data arrive.
Let $\mathbf{a} = f(\mathbf{z}_{\mathrm{old}})$ and $\mathbf{b} = f(\mathbf{z}_{\mathrm{new}})$ be the pseudo-outputs or inducing points before and after seeing
new data, where $\mathbf{z}_{\mathrm{old}}$ and $\mathbf{z}_{\mathrm{new}}$ are the corresponding pseudo-inputs. Note that extra pseudo-points can be
added or, conversely, old pseudo-points can be removed, i.e. the cardinalities of $\mathbf{a}$ and $\mathbf{b}$ need not be
the same. The previous posterior, $q_{\mathrm{old}}(f) = p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})q(\mathbf{a})$, can be used to find the approximate
likelihood given by the old observations as follows,
$$p(\mathbf{y}_{\mathrm{old}}|f) \approx \frac{q_{\mathrm{old}}(f)\,p(\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{old}})}{p(f|\theta_{\mathrm{old}})}.$$
Note that we have made the dependence on the hyperparameters explicit, as these will be optimized,
together with the variational parameters, using the variational free-energy. Substituting the approximate
likelihood above into the posterior that we want to target gives:
$$p(f|\mathbf{y}_{\mathrm{old}},\mathbf{y}_{\mathrm{new}}) = \frac{p(f|\theta_{\mathrm{new}})p(\mathbf{y}_{\mathrm{old}}|f)p(\mathbf{y}_{\mathrm{new}}|f)}{p(\mathbf{y}_{\mathrm{new}},\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{new}})} \approx \frac{p(f|\theta_{\mathrm{new}})q_{\mathrm{old}}(f)p(\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{old}})p(\mathbf{y}_{\mathrm{new}}|f)}{p(f|\theta_{\mathrm{old}})p(\mathbf{y}_{\mathrm{new}},\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{new}})}. \qquad (84)$$
The new posterior approximation takes the same form as the previous posterior, but with
the new pseudo-points and new hyperparameters: $q_{\mathrm{new}}(f) = p(f_{\neq\mathbf{b}}|\mathbf{b},\theta_{\mathrm{new}})q(\mathbf{b})$. This approximate
posterior can be obtained by minimizing the KL divergence,
$$\mathrm{KL}[q_{\mathrm{new}}(f)\,\|\,\hat{p}(f|\mathbf{y}_{\mathrm{old,new}})] = \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log\frac{p(f_{\neq\mathbf{b}}|\mathbf{b},\theta_{\mathrm{new}})q_{\mathrm{new}}(\mathbf{b})}{\frac{p(\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{old}})}{p(\mathbf{y}_{\mathrm{new}},\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{new}})}\,p(f|\theta_{\mathrm{new}})p(\mathbf{y}_{\mathrm{new}}|f)\frac{q_{\mathrm{old}}(f)}{p(f|\theta_{\mathrm{old}})}} \qquad (85)$$
$$= \log\frac{\mathcal{Z}_2(\theta_{\mathrm{new}})}{\mathcal{Z}_1(\theta_{\mathrm{old}})} + \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log\frac{p(\mathbf{a}|\theta_{\mathrm{old}})q_{\mathrm{new}}(\mathbf{b})}{p(\mathbf{b}|\theta_{\mathrm{new}})q_{\mathrm{old}}(\mathbf{a})p(\mathbf{y}_{\mathrm{new}}|f)}. \qquad (86)$$
The last equation above is obtained by noting that $p(f|\theta_{\mathrm{new}})/p(f_{\neq\mathbf{b}}|\mathbf{b},\theta_{\mathrm{new}}) = p(\mathbf{b}|\theta_{\mathrm{new}})$ and
$$\frac{q_{\mathrm{old}}(f)}{p(f|\theta_{\mathrm{old}})} = \frac{p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})q_{\mathrm{old}}(\mathbf{a})}{p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})p(\mathbf{a}|\theta_{\mathrm{old}})} = \frac{q_{\mathrm{old}}(\mathbf{a})}{p(\mathbf{a}|\theta_{\mathrm{old}})}.$$
Since the KL divergence is non-negative, the second term in (86) is the negative lower bound of
the approximate online log marginal likelihood, i.e. the variational free-energy, $\mathcal{F}(q_{\mathrm{new}}(f))$. We can
decompose the bound as follows,
$$\mathcal{F}(q_{\mathrm{new}}(f),\theta_{\mathrm{new}}) = \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log\frac{p(\mathbf{a}|\theta_{\mathrm{old}})q_{\mathrm{new}}(\mathbf{b})}{p(\mathbf{b}|\theta_{\mathrm{new}})q_{\mathrm{old}}(\mathbf{a})p(\mathbf{y}_{\mathrm{new}}|f)}$$
$$= \mathrm{KL}(q(\mathbf{b})\,\|\,p(\mathbf{b}|\theta_{\mathrm{new}})) - \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log p(\mathbf{y}_{\mathrm{new}}|f) - \int \mathrm{d}\mathbf{a}\, q_{\mathrm{new}}(\mathbf{a})\log\frac{q_{\mathrm{old}}(\mathbf{a})}{p(\mathbf{a}|\theta_{\mathrm{old}})}.$$
The first two terms form the variational free-energy as if the current batch were the whole training data, and
the last term constrains the posterior to take into account the old likelihood (through the old approximate
posterior and the prior).
C.2 Online variational free-energy approach using private pseudo-points with approximate maximum likelihood learning for the hyperparameters
Instead of using a common set of pseudo-points for all data points or streaming batches, we can assign
separate pseudo-points to each batch of data points as follows,
$$p(f|\mathbf{y}) \propto p(f)\prod_{r=1}^{R}p(\mathbf{y}_r|f) \approx p(f)\prod_{r=1}^{R}t_r(\mathbf{u}_r) \propto p(f_{\neq\mathbf{u}}|\mathbf{u})q(\mathbf{u}) = q(f), \qquad (87)$$
where $\mathbf{u}_r$ are the pseudo-points private to the $r$-th batch. As new data arrive, new pseudo-points
will be added to summarize the new data, while the old pseudo-points, corresponding to previously
seen batches, remain unchanged. This means we only need to add and adjust the new pseudo-points
and the new likelihood approximation for the new data points, as opposed to all pseudo-points and all
corresponding likelihood approximations as done in the previous section.
Similar to the online learning scheme in the previous section, we will try to approximate the running posterior
in eq. (84),
$$p(f|\mathbf{y}_{\mathrm{old}},\mathbf{y}_{\mathrm{new}}) \approx \frac{p(f|\theta_{\mathrm{new}})q_{\mathrm{old}}(f)p(\mathbf{y}_{\mathrm{new}}|f)}{p(f|\theta_{\mathrm{old}})}\frac{p(\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{old}})}{p(\mathbf{y}_{\mathrm{new}},\mathbf{y}_{\mathrm{old}}|\theta_{\mathrm{new}})}, \qquad (88)$$
where $q_{\mathrm{old}}(f) = p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})q(\mathbf{a})$
and $\mathbf{a}$ represents all pseudo-points used for previous batches. Let $\mathbf{b}$ be the new pseudo-points for the
new data and $t_b(\mathbf{b})$ be the contribution of the new data points $\mathbf{y}_{\mathrm{new}}$ towards the posterior. The new
approximate posterior is assumed to take the following form,
$$q_{\mathrm{new}}(f) = p(f_{\neq\mathbf{a},\mathbf{b}}|\mathbf{a},\mathbf{b},\theta_{\mathrm{new}})\,q(\mathbf{a})\,q(\mathbf{b}|\mathbf{a}), \qquad (89)$$
where we have chosen $q(\mathbf{b}|\mathbf{a}) \propto p(\mathbf{b}|\mathbf{a},\theta_{\mathrm{new}})t_b(\mathbf{b})$ and made the dependence on the hyperparameters
$\theta_{\mathrm{new}}$ implicit. Note that $q(\mathbf{a})$ is the variational distribution over the previous pseudo-points, and as such,
we only need to parameterize and learn the conditional distribution $q(\mathbf{b}|\mathbf{a})$.
Similar to the previous section, writing down the KL divergence from the running posterior in
eq. (88) to the approximate posterior in eq. (89), and ignoring constant terms, results in the online
variational free-energy,
$$\mathcal{F}(q_{\mathrm{new}}(f),\theta_{\mathrm{new}}) = \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log\frac{p(f|\theta_{\mathrm{new}})q_{\mathrm{old}}(f)p(\mathbf{y}_{\mathrm{new}}|f)}{p(f|\theta_{\mathrm{old}})q_{\mathrm{new}}(f)}. \qquad (90)$$
Note that,
$$\frac{q_{\mathrm{old}}(f)}{p(f|\theta_{\mathrm{old}})} = \frac{p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})q(\mathbf{a})}{p(f_{\neq\mathbf{a}}|\mathbf{a},\theta_{\mathrm{old}})p(\mathbf{a}|\theta_{\mathrm{old}})} = \frac{q(\mathbf{a})}{p(\mathbf{a}|\theta_{\mathrm{old}})}, \qquad (91)$$
$$\frac{p(f|\theta_{\mathrm{new}})}{q_{\mathrm{new}}(f)} = \frac{p(f_{\neq\mathbf{a},\mathbf{b}}|\mathbf{a},\mathbf{b},\theta_{\mathrm{new}})p(\mathbf{a},\mathbf{b}|\theta_{\mathrm{new}})}{p(f_{\neq\mathbf{a},\mathbf{b}}|\mathbf{a},\mathbf{b},\theta_{\mathrm{new}})q(\mathbf{a},\mathbf{b})} = \frac{p(\mathbf{a},\mathbf{b}|\theta_{\mathrm{new}})}{q(\mathbf{a},\mathbf{b})}. \qquad (92)$$
This leads to,
$$\mathcal{F}(q_{\mathrm{new}}(f),\theta_{\mathrm{new}}) = -\mathrm{KL}[q(\mathbf{a},\mathbf{b})\,\|\,p(\mathbf{a},\mathbf{b}|\theta_{\mathrm{new}})] + \int \mathrm{d}f\, q_{\mathrm{new}}(f)\log p(\mathbf{y}_{\mathrm{new}}|f) - \mathrm{H}[q(\mathbf{a})] - \int \mathrm{d}\mathbf{a}\, q(\mathbf{a})\log p(\mathbf{a}|\theta_{\mathrm{old}}). \qquad (93)$$
Note again that we are only optimizing the variational parameters of $q(\mathbf{b}|\mathbf{a})$ and the hyperparameters,
keeping $q(\mathbf{a})$ fixed.
C.3 Online variational free-energy approach for both hyperparameters and the
latent function with shared pseudo-points
The variational approaches above, while maintaining an online distributional approximation for the
latent function, retain only a point estimate of the hyperparameters. Imagine having observed the first
batch of data points in a regression task and trained the model on this batch, and that the second
batch contains only one data point. In this case, maximum likelihood learning of the hyperparameters
will tend to give a very large observation noise, i.e. the noise alone is used to explain the new data and
the latent function is largely ignored. Using the new model with the newly obtained hyperparameters
will thus result in poor predictions on previously seen data points.
We attempt to address this issue by maintaining a distributional approximation for the hyperparameters,
as well as one for the latent function, and adjusting these approximations using variational inference
as new data arrive. In particular, extending appendix C.1 by introducing a variational approximation
over the hyperparameters gives,
The likelihood of previously seen data points can be approximated via the approximate posterior as
follows,
Similar to the previous section, the online variational free-energy can be obtained by applying Jensen’s
inequality to the online log marginal likelihood, or by writing down the KL divergence as follows,
Most terms in the variational free-energy above require computing an expectation wrt the variational approximation q(θ), which is not available in closed-form even when q(θ) takes a simple form such as a diagonal Gaussian. However, these expectations can be approximated by simple Monte Carlo with the reparameterization trick [Kingma and Welling, 2014, Rezende et al., 2014]. As in the previous section, all other expectations can be handled tractably, either in closed-form or by using Gaussian quadrature.
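As an illustrative sketch, a reparameterized Monte Carlo estimate of E_{q(θ)}[g(θ)] for a diagonal Gaussian q(θ) might look as follows; in practice the sampling would live inside an autodiff framework (such as TensorFlow, which the implementation uses) so that gradients flow through the variational parameters, and the function name here is ours:

import numpy as np

def expect_reparam(m, log_s, g, n_samples=16, rng=None):
    """Monte Carlo estimate of E_{q(theta)}[g(theta)] with the reparameterization
    trick, for a diagonal Gaussian q(theta) = N(m, diag(exp(2 log_s))).
    Writing theta = m + exp(log_s) * eps with eps ~ N(0, I) keeps the estimate
    differentiable in (m, log_s) inside an autodiff framework."""
    rng = np.random.default_rng(0) if rng is None else rng
    eps = rng.standard_normal((n_samples, m.size))
    thetas = m + np.exp(log_s) * eps
    return float(np.mean([g(t) for t in thetas]))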
C.4 Online variational free-energy approach for both hyperparameters and the
latent function with private pseudo-points
As in appendix C.2, new pseudo-points can be allocated to new data as they arrive, and the current pseudo-points and their marginal variational approximation will remain fixed. The corresponding variational approximation for both the latent function and the hyperparameters is,
\[
q_{\text{new}}(f, \theta) = q_{\text{new}}(\theta)\, p(f_{\neq a,b} \mid a, b, \theta)\, q(a)\, q(b|a).
\]
The new approximate posterior above can be derived by approximating the conditional prior and the likelihood factor of the new data in the running posterior as follows,
\[
p(b|a, \theta)\, p(y_{\text{new}}|f, \theta) \approx t_1(b|a)\, t_2(\theta)\, t_3(b)\, t_4(\theta),
\]
where $\{t_i\}_{i=1}^{4}$ are the approximate factors representing the contributions of the conditional prior and the likelihood to the running posterior. In other words, $q(b|a) \propto t_1(b|a)\, t_3(b)$ and $q_{\text{new}}(\theta) \propto q_{\text{old}}(\theta)\, t_2(\theta)\, t_4(\theta)$. Substituting the above variational approximation into the online variational free-energy gives,
\[
\mathcal{F}(q_{\text{new}}(f, \theta)) = \int \mathrm{d}f\, \mathrm{d}\theta\, q_{\text{new}}(f, \theta) \log \frac{p(f_{\neq a,b}|a, b, \theta)\, q(a)\, q(b|a)\, q_{\text{new}}(\theta)}{p(f_{\neq a}|a, \theta)\, q(a)\, q_{\text{old}}(\theta)\, p(y_{\text{new}}|f, \theta)}
\]
\[
= \mathrm{KL}[q_{\text{new}}(\theta)\,\|\,q_{\text{old}}(\theta)] + \int \mathrm{d}\theta\, q_{\text{new}}(\theta)\, \mathrm{KL}[q(a, b)\,\|\,p(a, b|\theta)] - \int \mathrm{d}\theta\, q_{\text{new}}(\theta)\, \mathrm{KL}[q(a)\,\|\,p(a|\theta)] - \int \mathrm{d}f\, \mathrm{d}\theta\, q_{\text{new}}(f, \theta) \log p(y_{\text{new}}|f, \theta). \tag{94}
\]
Similar to the previous section, all terms in the free-energy above can be tractably handled in closed-form or by using simple Monte Carlo with the reparameterization trick [Kingma and Welling, 2014, Rezende et al., 2014].
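Putting these pieces together, a stochastic estimator of eq. (94) could be organized as below. This is a schematic sketch under our own naming: kl_theta is the closed-form KL between the Gaussian hyperparameter approximations, and the callables kl_ab_given, kl_a_given and exp_loglik_given are placeholders for the closed-form/quadrature computations described above, evaluated at a sampled θ.

def free_energy_estimate(kl_theta, sample_theta, kl_ab_given, kl_a_given,
                         exp_loglik_given, n_samples=8):
    """Schematic estimator of eq. (94): kl_theta = KL[q_new(theta)||q_old(theta)]
    is closed-form for Gaussian q(theta); the theta-dependent terms are averaged
    over reparameterized samples theta ~ q_new(theta)."""
    acc = 0.0
    for _ in range(n_samples):
        theta = sample_theta()
        acc += kl_ab_given(theta) - kl_a_given(theta) - exp_loglik_given(theta)
    return kl_theta + acc / n_samples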
Figure 10: Results of the streaming GP experiment on a toy classification data set: the performance of
several batch and streaming methods after seeing all training points. In the batch case, we consider
three inference methods: MCMC for both the latent variable and the hyperparameters, VI for both the
latent function and the hyperparameters, and VI for the latent function and approximate maximum
likelihood learning for the hyperparameters. The latter two methods are also tested in the streaming setting. We show the predictions made by the methods after training in the batch case, and after
seeing all three batches in the streaming case. The (distributional) hyperparameter estimates are also
included. Best viewed in colour.
where $\{x_n, y_n\}_{n=1}^{N}$ are the training points and $\theta$ denotes the network weights and biases. However, the exact posterior is analytically intractable and, as such, approximation methods are needed. In this section, we discuss several approximation strategies for training a Bayesian neural network on the standard MNIST ten-way classification data set. In particular, we focus on the case where the data are decentralized on different machines, that is, we further assume that the N training points are partitioned into K = 10 disjoint memory shards. Furthermore, two levels of data homogeneity across memory shards are considered: homogeneous [or iid, that is, each shard has training points of all classes] and inhomogeneous [or non-iid, i.e. each shard has training points of only one class].

Figure 11: Results of the streaming GP experiment on a toy classification data set: the performance of the streaming methods after seeing each data batch. Two methods were considered: VI for both the latent function and the hyperparameters, and VI for the latent function and approximate maximum likelihood learning for the hyperparameters. We show the predictions made by the methods after seeing each data batch and the corresponding (distributional) hyperparameter estimates. Best viewed in colour.
We place a diagonal standard Normal prior over the parameters, $p(\theta) = \mathcal{N}(\theta; 0, I)$, and initialize the mean of the variational approximations as suggested by Glorot and Bengio [2010]. For distributed training methods, the data set is partitioned into 10 subsets or shards, with 10 compute nodes (workers), each able to access one memory shard. The different inference strategies were implemented in TensorFlow [Abadi et al., 2016] and the workload between workers was managed using Ray [Moritz et al., 2017].
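For reference, the homogeneous and inhomogeneous splits can be reproduced with a simple partitioning routine along the following lines (a sketch; the function name and seed handling are ours):

import numpy as np

def make_shards(x, y, k=10, iid=True, seed=0):
    """Partition (x, y) into k disjoint memory shards.
    iid=True : random split, so every shard contains all classes.
    iid=False: sort by label first, so each shard holds (roughly) one class."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(y)) if iid else np.argsort(y, kind="stable")
    return [(x[idx], y[idx]) for idx in np.array_split(order, k)]

With k = 10 and MNIST's ten balanced classes, the non-iid split assigns essentially one digit class per shard, matching the inhomogeneous setting above.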
Figure 12: Results of the streaming GP experiment on a toy classification data set: a failure case of
maximum likelihood learning for the hyperparameters. Two methods were considered: VI for both the
latent function and the hyperparameters, and VI for the latent function and approximate maximum
likelihood learning for the hyperparameters. We show the predictions made by the methods after seeing
each data batch and the corresponding (distributional) hyperparameter estimates. Best viewed in
colour.
E.1 Global VI
We first considered global variational inference, as described in section 5, to obtain an approximate posterior over the parameters. The variational lower bound (eq. (18)) is optimized using Adam [Kingma and Ba, 2014]. We considered one compute node (with either one core or ten cores) that can access the entire data set and simulates the data distribution by sequentially presenting mini-batches that can potentially have all ten classes (iid) or that have data of only one class (non-iid). The full performance on the test set during training, for different learning rate hyperparameters of the Adam optimizer, is shown in figs. 13 and 14. Notice that in the iid setting, larger learning rates tend to yield faster convergence but can give slightly poorer predictive performance on the test set at the end of training (see fig. 14 with a learning rate of 0.005). The non-iid setting is arguably more difficult, and the performance can oscillate if the learning rate is too large.
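For concreteness, one stochastic evaluation of the mini-batch lower bound being optimized here can be sketched as follows, assuming a diagonal Gaussian q(θ) and a user-supplied log-likelihood routine; all names are illustrative, and the actual implementation uses TensorFlow so that Adam can consume the gradients:

import numpy as np

def minibatch_elbo(m, log_s, x_batch, y_batch, loglik, n_total,
                   n_samples=4, rng=None):
    """One stochastic estimate of the global variational lower bound for a
    BNN with q(theta) = N(m, diag(exp(2 log_s))) and prior p(theta) = N(0, I).
    loglik(theta, x, y) (user-supplied) sums log p(y | x, theta) over the
    mini-batch; n_total is the full data set size N."""
    rng = np.random.default_rng(0) if rng is None else rng
    # closed-form KL[q(theta) || N(0, I)] for a diagonal Gaussian
    kl = 0.5 * np.sum(np.exp(2.0 * log_s) + m**2 - 1.0 - 2.0 * log_s)
    # reparameterized Monte Carlo estimate of the expected log-likelihood
    eps = rng.standard_normal((n_samples, m.size))
    thetas = m + np.exp(log_s) * eps
    ell = np.mean([loglik(t, x_batch, y_batch) for t in thetas])
    return (n_total / len(y_batch)) * ell - kl

The N/B rescaling of the batch log-likelihood makes the estimate unbiased for the full-data bound, which is what permits the mini-batch simulation of the iid and non-iid data streams.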
[Figure 13 plots: classification error /% and negative log-likelihood /nat against train time /s. Panels: (a) one compute node with one core and iid mini-batches; (b) one compute node with one core and non-iid mini-batches.]
Figure 13: The performance of global VI on the test set in the iid [left] and non-iid [right] settings, when the compute node has only one core. Different traces correspond to different learning rate hyperparameters of Adam.
[Figure 14 plots: classification error /% and negative log-likelihood /nat against train time /s. Panels: (a) one compute node with ten cores and iid mini-batches; (b) one compute node with ten cores and non-iid mini-batches.]
Figure 14: The performance of global VI on the test set in the iid and non-iid settings, when the compute node has ten cores. Different traces correspond to different learning rate hyperparameters of Adam.
[Figure 15 plots: classification error /% and nll /nat against train time /s, with traces for learning rates 0.001, 0.005 and 0.01. Panels: (a) BCM with the same N(0, 1) prior across 10 workers and iid data; (b) BCM with the prior N(0, 1) being split equally across 10 workers and iid data.]
Figure 15: Performance of BCM with two prior sharing strategies in the iid setting, for various learning rates. Best viewed in colour.
[Figure 16 plots: classification error /% and nll /nat against train time /s, with traces for learning rates 0.001, 0.005 and 0.01. Panels: (a) BCM with the same N(0, 1) prior across workers and non-iid data; (b) BCM with the prior N(0, 1) being split equally across workers and non-iid data.]
Figure 16: Performance of BCM with two prior sharing strategies in the non-iid setting, for various learning rates. Best viewed in colour.
[Figure 17 plots: classification error /% and nll /nat against train time /s; one row per Adam learning rate (0.001, 0.005, 0.010), with traces for 1, 10, 20, 50 and 100 epochs per worker.]
Figure 17: Performance of sequential PVI with only one pass through all memory shards when the data are iid. The number of epochs for each worker and the learning rate hyperparameter of Adam were varied. Best viewed in colour.
[Figure 18 plots: classification error /% and nll /nat against train time /s; one row per Adam learning rate (0.001, 0.005, 0.010), with traces for 1, 10, 20, 50 and 100 epochs per worker.]
Figure 18: Performance of sequential PVI with only one pass through all memory shards when the data are non-iid. The number of epochs for each worker and the learning rate hyperparameter of Adam were varied. Best viewed in colour.
[Figure 19 plots: classification error /% and nll /nat against train time /s; one row per damping factor (0.95, 0.90, 0.80), with traces for different Adam learning rates.]
Figure 19: Performance of PVI with synchronous updates when the data are iid. In this experiment, each worker communicates with the central server after one epoch. The learning rate hyperparameter of Adam and the damping factor were varied. Best viewed in colour.
[Figure 20 plots: (a) classification error /% and (b) nll /nat against train time /s; one row per damping factor (0.95, 0.90, 0.80), with traces for learning rates 0.0001, 0.0005, 0.0010 and 0.0050.]
Figure 20: Performance of PVI with synchronous updates when the data are non-iid. In this experiment, each worker communicates with the central server after one epoch. The learning rate hyperparameter of Adam and the damping factor were varied. Best viewed in colour.
[Figure 21 plots: classification error /% and nll /nat during training, with traces for the train and test sets.]
Figure 21: For certain hyperparameter settings, PVI with synchronous updates worryingly exhibits over-fitting.
[Figure 22 plots: classification error /% and nll /nat against train time /s; one row per damping factor (0.95, 0.90, 0.80), with traces for different Adam learning rates.]
Figure 22: Performance of PVI with asynchronous, lock-free updates when the data are iid. In this experiment, each worker communicates with the central server after one epoch. The learning rate hyperparameter of Adam and the damping factor were varied. Best viewed in colour.
[Figure 23 plots: (a) classification error /% and (b) nll /nat against train time /s; one row per damping factor (0.95, 0.90, 0.80), with traces for learning rates 0.0001, 0.0005, 0.0010 and 0.0050.]
Figure 23: Performance of PVI with asynchronous, lock-free updates when the data are non-iid. In this experiment, each worker communicates with the central server after one epoch. The learning rate hyperparameter of Adam and the damping factor were varied. Best viewed in colour.
E.6 Stochastic Natural Gradient Variational Inference for Bayesian Neural Networks
In this experiment, we stress-test various optimization methods for global variational inference for
Bayesian neural networks. In particular, we consider two methods: (i) stochastic natural-gradient global
VI with a fixed learning rate (SNGD, see eq. (13)), and (ii) stochastic flat gradient global VI with an
adaptive learning rate provided by Adam [Kingma and Ba, 2014]. Two Bayesian neural networks, each with one hidden layer of 200 or 500 ReLU hidden units, and the standard MNIST ten-class classification problem are employed for this experiment. The networks are trained using mini-batches of 200 data points for 800 or 1000 epochs, and both optimization methods have similar running time. The full results are included in figs. 24, 25, 27 and 28 and key results are shown in figs. 26 and 29. It can be seen from figs. 26 and 29 that the best versions of SNGD and Adam perform similarly in terms of both classification error and convergence speed/data efficiency. However, both methods require tuning of the learning rate hyperparameter. As already observed in the global VI experiment in section 7.1, fast convergence early in training when using Adam does not necessarily result in good predictive performance at the end.
As mentioned in the main text, while natural gradients have been shown to be effective in the batch, global VI setting [Honkela et al., 2010], the result presented here could be seen as a negative result for natural-gradient based methods: a stochastic natural-gradient/fixed-point method with a fixed learning rate does not outperform an adaptive stochastic flat-gradient method. However, this might not be surprising, as Adam adjusts its step sizes based on approximate second-order information about the objective. It also suggests a future research avenue: developing effective adaptive optimization schemes for stochastic natural-gradient variational inference.
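For reference, a fixed-learning-rate SNGD step of the kind stress-tested here can be sketched as follows for a diagonal Gaussian approximation. It relies on the standard exponential-family identity that the natural gradient with respect to the natural parameters equals the ordinary gradient with respect to the expectation parameters; grad_mean_params, which returns the latter for the current mini-batch, is a hypothetical user-supplied routine:

import numpy as np

def sngd_step(lam1, lam2, grad_mean_params, lr=2e-4):
    """One fixed-learning-rate stochastic natural-gradient step for a diagonal
    Gaussian q(theta) in natural parameters: lam1 = m / v, lam2 = -0.5 / v,
    with v the variance vector. The natural gradient w.r.t. (lam1, lam2)
    equals the flat gradient w.r.t. the expectation parameters
    (E[theta], E[theta^2]), which grad_mean_params (hypothetical,
    user-supplied) returns for the current mini-batch."""
    g1, g2 = grad_mean_params(lam1, lam2)
    lam1 = lam1 + lr * g1
    lam2 = np.minimum(lam2 + lr * g2, -1e-8)   # keep the precision positive
    return lam1, lam2

# Recover the moments when needed: v = -0.5 / lam2, m = v * lam1.

The clipping of lam2 is one simple way to preserve a valid (positive-variance) Gaussian after a noisy step; with a fixed learning rate there is no mechanism, as in Adam, to shrink steps when the stochastic gradients are large or noisy, which is consistent with the sensitivity observed above.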
[Figure 24 plots: error /% against epoch; train and test traces for Adam (lr = 0.001, 0.005, 0.01, 0.02, 0.03) and SNGD (lr = 0.0001, 0.00015, 0.0002, 0.00025, 0.0003).]
Figure 24: Classification error rates on the train and test sets during training using Adam and Stochastic
Natural Gradient (SNGD) methods on the MNIST classification task with a Bayesian neural network
with one hidden layer of 200 rectified linear units. The final performance of all settings is shown in fig. 26. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best
learning rates for both methods give similar train and test results and yield similar convergence. Note
that while Adam adaptively changes the learning rate based on the gradient statistics, SNGD employs
a fixed step size. See text for more details. Best viewed in colour.
[Figure 25 plots: nll /nat against epoch; train and test traces for the same Adam and SNGD learning rates as in fig. 24.]
Figure 25: Negative log-likelihoods on the train and test sets during training using Adam and Stochastic
Natural Gradient (SNGD) methods on the MNIST classification task with a Bayesian neural network
with one hidden layer of 200 rectified linear units. The final performance of all settings is shown in fig. 26. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best
learning rates for both methods give similar train and test results and yield similar convergence. Note
that while Adam adaptively changes the learning rate based on the gradient statistics, SNGD employs
a fixed step size. See text for more details. Best viewed in colour.
Figure 26: Performance on the train set [left] and test set [right] after 1000 epochs using Adam and Stochastic Natural Gradient (SNGD)
methods on the MNIST classification task with a Bayesian neural network with one hidden layer of 200 rectified linear units, and the typical
performance traces as training progresses [inset plots]. This figure summarizes the full results in figs. 24 and 25. The performance is measured using the classification error [error] and the negative log-likelihood [nll]; for both measures, lower is better and, as such, closer to the bottom left is better. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best learning rates for both
methods give similar train and test results and yield similar convergence. Note that while Adam adaptively changes the learning rate based
on the gradient statistics, SNGD employs a fixed step size. See text for more details. Best viewed in colour.
[Figure 27 plots: error /% against epoch; train and test traces for Adam (lr = 0.001, 0.005, 0.01, 0.02, 0.03) and SNGD (lr = 0.0001, 0.00015, 0.0002, 0.00025, 0.0003).]
Figure 27: Classification error rates on the train and test sets during training using Adam and Stochastic
Natural Gradient (SNGD) methods on the MNIST classification task with a Bayesian neural network
with one hidden layer of 500 rectified linear units. The final performance of all settings is shown in fig. 29. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best
learning rates for both methods give similar train and test results and yield similar convergence. Note
that while Adam adaptively changes the learning rate based on the gradient statistics, SNGD employs
a fixed step size. See text for more details. Best viewed in colour.
[Figure 28 plots: nll /nat against epoch; train and test traces for the same Adam and SNGD learning rates as in fig. 27.]
Figure 28: Negative log-likelihoods on the train and test sets during training using Adam and Stochastic
Natural Gradient (SNGD) methods on the MNIST classification task with a Bayesian neural network
with one hidden layer of 500 rectified linear units. The final performance of all settings is shown in fig. 29. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best
learning rates for both methods give similar train and test results and yield similar convergence. Note
that while Adam adaptively changes the learning rate based on the gradient statistics, SNGD employs
a fixed step size. See text for more details. Best viewed in colour.
Figure 29: Performance on the train set [left] and test set [right] after 800 epochs using Adam and Stochastic Natural Gradient (SNGD)
methods on the MNIST classification task with a Bayesian neural network with one hidden layer of 500 rectified linear units, and the
typical performance traces as training progresses [inset plots]. This figure summarizes the full results in figs. 27 and 28. The performance is measured using the classification error [error] and the negative log-likelihood [nll]; for both measures, lower is better and, as such, closer to the bottom left is better. For both Adam and SNGD, the performance depends strongly on the learning rate, but the best learning rates for both methods give similar train and test results and yield
similar convergence. Note that while Adam adaptively changes the learning rate based on the gradient statistics, SNGD employs a fixed step
size. See text for more details. Best viewed in colour.