(2024-ICML) Variational Schrodinger Diffusion Models
method for optimizing transportation plans in diffusion models. However, SB requires estimating the intractable forward score functions, inevitably resulting in the costly implicit training loss based on simulated trajectories. To improve the scalability while preserving efficient transportation plans, we leverage variational inference to linearize the forward score functions (variational scores) of SB and restore simulation-free properties in training backward scores. We propose the variational Schrödinger diffusion model (VSDM), where the forward process is a multivariate diffusion and the variational scores are adaptively optimized for efficient transport. Theoretically, we use stochastic approximation to prove the convergence of the variational scores and show the convergence of the adaptively generated samples based on the optimal variational scores. Empirically, we test the algorithm in simulated examples and observe that VSDM is efficient in generations of anisotropic shapes and yields straighter sample trajectories compared to the single-variate diffusion. We also verify the scalability of the algorithm in real-world data and achieve competitive unconditional generation performance in CIFAR10 and conditional generation in time series modeling. Notably, VSDM no longer depends on warm-up initializations and has become tuning-friendly in training large-scale experiments.

et al., 2022; Kong et al., 2021; Ramesh et al., 2022; Zhang et al., 2024). The key to their scalability lies in the closed-form updates of the forward process, highlighting both statistical efficiency (Koehler et al., 2023) and diminished dependence on dimensionality (Vono et al., 2022). Nevertheless, diffusion models lack a distinct guarantee of optimal transport (OT) properties (Lavenant & Santambrogio, 2022) and often necessitate costly evaluations to generate higher-fidelity content (Ho et al., 2020; Salimans & Ho, 2022; Lu et al., 2022; Xue et al., 2023; Luo, 2023).

Alternatively, the Schrödinger bridge (SB) problem (Léonard, 2014; Chen & Georgiou, 2016; Pavon et al., 2021; Caluya & Halder, 2022; De Bortoli et al., 2021), initially rooted in quantum mechanics (Léonard, 2014), proposes optimizing a stochastic control objective through the use of forward-backward stochastic differential equations (FB-SDEs) (Chen et al., 2022b). The alternating solver gives rise to the iterative proportional fitting (IPF) algorithm (Kullback, 1968; Ruschendorf, 1995) in dynamic optimal transport (Villani, 2003; Peyré & Cuturi, 2019). Notably, the intractable forward score function plays a crucial role in providing theoretical guarantees in optimal transport (Chen et al., 2023c; Deng et al., 2024). However, it simultaneously sacrifices the simulation-free property and largely relies on warm-up checkpoints for conducting large-scale experiments (De Bortoli et al., 2021; Chen et al., 2022b). A natural follow-up question arises:

Can we train diffusion models with efficient transport?
Theoretically, we leverage stochastic approximation (Robbins & Monro, 1951) to demonstrate the convergence of the variational score to the optimal local estimators. Although the global transport optimality is compromised, the notable simulation-free speed-ups in training the backward score render the algorithm particularly attractive for training various generation tasks from scratch. Additionally, the efficiency of simulation-based training for the linearized variational score significantly improves owing to computational advancements in convex optimization. We validate the strength of VSDM through simulations, achieving compelling performance on standard image generation tasks. Our contributions unfold in four key aspects:

• We introduce the variational Schrödinger diffusion model (VSDM), a multivariate diffusion with optimal variational scores guided by optimal transport. Additionally, the training of backward scores is simulation-free and becomes much more scalable.

• We study the convergence of the variational score using stochastic approximation (SA) theory, which can be further generalized to a class of state space diffusion models for future developments.

• VSDM is effective in generating data of anisotropic shapes and motivates straighter transportation paths via the optimized transport.

• VSDM achieves competitive unconditional generation on CIFAR10 and conditional generation in time series modeling without reliance on warm-up initializations.

2. Related Works

Flow Matching and Beyond. Lipman et al. (2023) utilized the McCann displacement interpolation (McCann, 1997) to train simulation-free CNFs to encourage straight trajectories. Consequently, Pooladian et al. (2023); Tong et al. (2023) proposed straightening by using minibatch optimal transport solutions. Similar ideas were achieved by Liu (2022); Liu et al. (2023) to iteratively rectify the interpolation path. Albergo & Vanden-Eijnden (2023); Albergo et al. (2023) developed the stochastic interpolant approach to unify both flow and diffusion models. However, “straighter” transport maps may not imply optimal transportation plans in general and the couplings are still not effectively optimized.

Dynamic Optimal Transport. Finlay et al. (2020); Onken et al. (2021) introduced additional regularization through optimal transport to enforce straighter trajectories in CNFs and reduce the computational cost. De Bortoli et al. (2021); Chen et al. (2022b); Vargas et al. (2021) studied the dynamic Schrödinger bridge with guarantees in entropic optimal transport (EOT) (Chen et al., 2023c); Shi et al. (2023); Peluchetti (2023); Chen et al. (2023b) generalized bridge matching and flow matching based on EOT and obtained smoother trajectories; however, scalability remains a significant concern for Schrödinger-based diffusions.

3. Preliminaries

3.1. Diffusion Models

The score-based generative models (SGMs) (Ho et al., 2020; Song et al., 2021b) first employ a forward process (1a) to map data to an approximate Gaussian and subsequently reverse the process in Eq.(1b) to recover the data distribution.

$$
\begin{aligned}
d\overrightarrow{x}_t &= f_t(\overrightarrow{x}_t)\,dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, && \text{(1a)} \\
d\overleftarrow{x}_t &= \left[f_t(\overleftarrow{x}_t) - \beta_t \nabla \log \rho_t(\overleftarrow{x}_t)\right]dt + \sqrt{\beta_t}\,d\overleftarrow{w}_t, && \text{(1b)}
\end{aligned}
$$

where $\overleftarrow{x}_t, \overrightarrow{x}_t \in \mathbb{R}^d$; $\overrightarrow{x}_0 \sim \rho_{\text{data}}$ and $\overleftarrow{x}_T \sim \rho_{\text{prior}}$; $f_t$ denotes the vector field and is often set to $0$ (a.k.a. VE-SDE) or linear in $x$ (a.k.a. VP-SDE); $\beta_t > 0$ is the time-varying scalar; $\overrightarrow{w}_t$ is a forward Brownian motion from $t \in [0, T]$ with $\rho_T \approx \rho_{\text{prior}}$; $\overleftarrow{w}_t$ is a backward Brownian motion from time $T$ to $0$. The marginal density $\rho_t$ of the forward process (1a) is essential for generating the data but remains inaccessible in practice due to intractable normalizing constants.

Explicit Score Matching (ESM). Instead, the conditional score function $\nabla \log \rho_{t|0}(\cdot) \equiv \nabla \log \rho_t(\cdot \mid \overrightarrow{x}_0)$ is estimated by minimizing a user-friendly ESM loss (weighted by $\lambda$) between the score estimator $s_t \equiv s_\theta(\cdot, t)$ and exact score (Song et al., 2021b) such that

$$
\mathbb{E}_t\Big[\lambda_t\, \mathbb{E}_{\overrightarrow{x}_0} \mathbb{E}_{\overrightarrow{x}_t \mid \overrightarrow{x}_0}\big[\|s_t(\overrightarrow{x}_t) - \nabla \log \rho_{t|0}(\overrightarrow{x}_t)\|_2^2\big]\Big]. \tag{2}
$$

Notably, both VP- and VE-SDEs yield closed-form expressions for any $\overrightarrow{x}_t$ given $\overrightarrow{x}_0$ in the forward process (Song et al., 2021b), which is instrumental for the scalability of diffusion models in real-world large-scale generation tasks.

Implicit Score Matching (ISM). By integration by parts, ESM is equivalent to the ISM loss (Hyvärinen, 2005; Huang et al., 2021; Luo et al., 2024b) and the evidence lower bound (ELBO) follows

$$
\log \rho_0(x_0) \ge \mathbb{E}_{\rho_{T|0}(\cdot)}\big[\log \rho_{T|0}(x_T)\big] - \frac{1}{2}\int_0^T \mathbb{E}_{\rho_{t|0}(\cdot)}\Big[\beta_t \|s_t\|_2^2 + 2\nabla \cdot (\beta_t s_t - f_t)\Big]\,dt.
$$

ISM is naturally connected to Song et al. (2020), which supports flexible marginals and nonlinear forward processes but becomes significantly less scalable compared to ESM.

3.2. Schrödinger Bridge

The dynamic Schrödinger bridge aims to solve a full bridge

$$
\inf_{\mathbb{P} \in \mathcal{D}(\rho_{\text{data}}, \rho_{\text{prior}})} \mathrm{KL}(\mathbb{P} \| \mathbb{Q}), \tag{3}
$$
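where $\mathcal{D}(\rho_{\text{data}}, \rho_{\text{prior}})$ is the family of path measures with marginals $\rho_{\text{data}}$ and $\rho_{\text{prior}}$ at $t = 0$ and $t = T$, respectively; $\mathbb{Q}$ is the prior process driven by $dx_t = f_t(x_t)\,dt + \sqrt{2\beta_t \varepsilon}\,d\overrightarrow{w}_t$. It also yields a stochastic control formulation (Chen et al., 2021; Pavon et al., 2021; Caluya & Halder, 2022):

$$
\begin{aligned}
&\inf_{u \in \mathcal{U}} \mathbb{E}\left\{\int_0^T \frac{1}{2}\|u_t(\overrightarrow{x}_t)\|_2^2\,dt\right\} \\
&\text{s.t. } d\overrightarrow{x}_t = \left[f_t(\overrightarrow{x}_t) + \sqrt{\beta_t}\,u_t(\overrightarrow{x}_t)\right]dt + \sqrt{2\beta_t \varepsilon}\,d\overrightarrow{w}_t, \\
&\overrightarrow{x}_0 \sim \rho_{\text{data}}, \quad \overrightarrow{x}_T \sim \rho_{\text{prior}},
\end{aligned} \tag{4}
$$

where $\mathcal{U}$ is the family of controls. The expectation is taken w.r.t. $\overrightarrow{\rho}_t(\cdot)$, which denotes the PDF of the controlled diffusion (4); $\varepsilon$ is the temperature of the diffusion and the regularizer in EOT (Chen et al., 2023c).

Solving the underlying Hamilton–Jacobi–Bellman (HJB) equation and invoking the time reversal (Anderson, 1982) with $\varepsilon = \frac{1}{2}$, the Schrödinger system yields the desired forward-backward stochastic differential equations (FB-SDEs) (Chen et al., 2022b):

$$
\begin{aligned}
d\overrightarrow{x}_t &= \left[f_t(\overrightarrow{x}_t) + \beta_t \nabla \log \overrightarrow{\psi}_t(\overrightarrow{x}_t)\right]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, && \text{(5a)} \\
d\overleftarrow{x}_t &= \left[f_t(\overleftarrow{x}_t) - \beta_t \nabla \log \overleftarrow{\varphi}_t(\overleftarrow{x}_t)\right]dt + \sqrt{\beta_t}\,d\overleftarrow{w}_t, && \text{(5b)}
\end{aligned}
$$

where $\overrightarrow{\psi}_t(\cdot)\,\overleftarrow{\varphi}_t(\cdot) = \overrightarrow{\rho}_t(\cdot)$, $\rho_0(\cdot) \sim \rho_{\text{data}}$, $\rho_T(\cdot) \sim \rho_{\text{prior}}$.

To solve the optimal controls (scores) $(\nabla \log \overrightarrow{\psi}, \nabla \log \overleftarrow{\varphi})$, a standard tool is to leverage the nonlinear Feynman-Kac formula (Ma & Yong, 2007; Karatzas & Shreve, 1998; Chen et al., 2022b) to learn a stochastic representation.

Proposition 1 (Nonlinear Feynman-Kac representation). Assume Lipschitz smoothness and linear growth condition on the drift $f$ and diffusion $g$ in the FB-SDE (5). Define $\overrightarrow{y}_t = \log \overrightarrow{\psi}_t(x_t)$ and $\overleftarrow{y}_t = \log \overleftarrow{\varphi}_t(x_t)$. Then the stochastic representation follows

$$
\begin{aligned}
\overleftarrow{y}_s &= \mathbb{E}\left[\overleftarrow{y}_T - \int_s^T \Gamma_\zeta(\overleftarrow{z}_t; \overrightarrow{z}_t)\,dt \,\Big|\, \overrightarrow{x}_s = x_s\right], \\
\Gamma_\zeta(\overleftarrow{z}_t; \overrightarrow{z}_t) &\equiv \frac{1}{2}\|\overleftarrow{z}_t\|_2^2 + \nabla \cdot \left(\sqrt{\beta_t}\,\overleftarrow{z}_t - f_t\right) + \zeta \langle \overleftarrow{z}_t, \overrightarrow{z}_t \rangle,
\end{aligned} \tag{6}
$$

where $\overrightarrow{z}_t = \sqrt{\beta_t}\,\nabla \overrightarrow{y}_t$, $\overleftarrow{z}_t = \sqrt{\beta_t}\,\nabla \overleftarrow{y}_t$, and $\zeta = 1$.

4. Variational Schrödinger Diffusion Models

SB outperforms SGMs in the theoretical potential of optimal transport and an intractable score function $\nabla \log \overrightarrow{\psi}_t(x_t)$ is exploited in the forward SDE for more efficient transportation plans. However, there is no free lunch in achieving such efficiency, and it comes with three notable downsides:

• Solving $\nabla \log \overrightarrow{\psi}_t$ in Eq.(5a) for optimal transport is prohibitively costly and may not be necessary (Marzouk et al., 2016; Liu et al., 2023).

• The nonlinear diffusion no longer yields a closed-form expression of $\overrightarrow{x}_t$ given $\overrightarrow{x}_0$ (Chen et al., 2022b).

• The ISM loss is inevitable and the estimator suffers from a large variance issue (Hutchinson, 1989).

4.1. Variational Inference via Linear Approximation

FB-SDEs naturally connect to the alternating-projection solver based on the IPF (a.k.a. Sinkhorn) algorithm, boiling down the full bridge (3) to a half-bridge solver (Pavon et al., 2021; De Bortoli et al., 2021; Vargas et al., 2021). With $\mathbb{P}_1$ given and $k = 1, 2, \ldots$, we have:

$$
\begin{aligned}
\mathbb{P}_{2k} &:= \operatorname*{arg\,min}_{\mathbb{P} \in \mathcal{D}(\rho_{\text{data}}, \cdot)} \mathrm{KL}(\mathbb{P} \| \mathbb{P}_{2k-1}), && \text{(7a)} \\
\mathbb{P}_{2k+1} &:= \operatorname*{arg\,min}_{\mathbb{P} \in \mathcal{D}(\cdot, \rho_{\text{prior}})} \mathrm{KL}(\mathbb{P} \| \mathbb{P}_{2k}). && \text{(7b)}
\end{aligned}
$$

More specifically, Chen et al. (2022b) proposed a neural network parameterization to model $(\overleftarrow{z}_t, \overrightarrow{z}_t)$ using $(\overleftarrow{z}^\theta_t, \overrightarrow{z}^\omega_t)$, where $\theta$ and $\omega$ refer to the model parameters, respectively. Each stage of the half-bridge solver proposes to solve the models alternatingly as follows

$$
\begin{aligned}
\overleftarrow{\mathcal{L}}(\theta) &= -\int_0^T \mathbb{E}_{\overrightarrow{x}_t \backsim \text{(5a)}}\left[\Gamma_1(\overleftarrow{z}^\theta_t; \overrightarrow{z}^\omega_t)\,dt \,\Big|\, \overrightarrow{x}_0 = x_0\right], && \text{(8a)} \\
\overrightarrow{\mathcal{L}}(\omega) &= -\int_0^T \mathbb{E}_{\overleftarrow{x}_t \backsim \text{(5b)}}\left[\Gamma_1(\overrightarrow{z}^\omega_t; \overleftarrow{z}^\theta_t)\,dt \,\Big|\, \overleftarrow{x}_T = x_T\right], && \text{(8b)}
\end{aligned}
$$

where $\Gamma_1$ is defined in Eq.(6) and $\backsim$ denotes the approximate simulation parametrized by neural networks.*

* $\sim$ (resp. $\backsim$) denotes the exact (resp. parametrized) simulation.

However, solving the backward score in Eq.(8a) through simulations, akin to the ISM loss, is computationally demanding and affects the scalability in generative models. To motivate the simulation-free property, we leverage variational inference (Blei et al., 2017) and study a linear approximation of the forward score $\nabla \log \overrightarrow{\psi}(x, t) \approx \mathbf{A}_t x$ with $f_t(\overrightarrow{x}_t) \equiv -\frac{1}{2}\beta_t \overrightarrow{x}_t$, which ends up with the variational FB-SDE (VFB-SDE):

$$
\begin{aligned}
d\overrightarrow{x}_t &= \left[-\frac{1}{2}\beta_t \overrightarrow{x}_t + \beta_t \mathbf{A}_t \overrightarrow{x}_t\right]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, && \text{(9a)} \\
d\overleftarrow{x}_t &= \left[-\frac{1}{2}\beta_t \overleftarrow{x}_t - \beta_t \nabla \log \overrightarrow{\rho}_t(\overleftarrow{x}_t)\right]dt + \sqrt{\beta_t}\,d\overleftarrow{w}_t, && \text{(9b)}
\end{aligned}
$$

where $t \in [0, T]$ and $\nabla \log \overrightarrow{\rho}_t$ is the score function of (9a) and the conditional version is to be derived in Eq.(15).

The half-bridge solver is restricted to a class of OU processes $\mathrm{OU}(\rho_{\text{data}}, \cdot)$ with the initial marginal $\rho_{\text{data}}$.

$$
\operatorname*{arg\,min}_{\mathbb{P} \in \mathcal{D}(\rho_{\text{data}}, \cdot)} \mathrm{KL}(\mathbb{P} \| \mathbb{P}_{2k-1}) \;\Rightarrow\; \operatorname*{arg\,min}_{\widehat{\mathbb{P}} \in \mathrm{OU}(\rho_{\text{data}}, \cdot)} \mathrm{KL}(\widehat{\mathbb{P}} \| \mathbb{P}_{2k-1}).
$$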
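As a rough illustration of the variational forward SDE (9a), the sketch below integrates it with an Euler–Maruyama scheme in NumPy. It assumes a constant β and a time-invariant matrix A (the paper allows both to depend on t), and the function name `simulate_vfb_forward` and all parameter values are hypothetical choices made for illustration, not the authors' implementation.

```python
import numpy as np


def simulate_vfb_forward(x0, A, beta, T=1.0, n_steps=1000, seed=0):
    """Euler-Maruyama for dx = (-0.5*beta*x + beta*A x) dt + sqrt(beta) dw.

    Assumptions (not from the paper): beta is a constant scalar and A is a
    fixed (d, d) matrix; x0 has shape (n_samples, d).
    """
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    h = T / n_steps
    d = A.shape[0]
    # Linear drift matrix of Eq.(9a): beta * (A - I/2) = -0.5 * beta * D.
    drift = beta * (A - 0.5 * np.eye(d))
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + h * (x @ drift.T) + np.sqrt(beta * h) * noise
    return x


# Hypothetical usage: D = I - 2A = diag(2, 0.5) is positive definite.
A = np.diag([-0.5, 0.25])
x0 = np.zeros((20000, 2))
xT = simulate_vfb_forward(x0, A, beta=4.0, T=2.0, n_steps=400, seed=0)
empirical_var = xT.var(axis=0)
```

With D = I − 2A = diag(2, 0.5), the empirical variances should approach the stationary values 1/2 and 2 of the corresponding OU process, up to discretization and Monte Carlo error; a different A changes how anisotropic the terminal Gaussian is.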
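By the mode-seeking property of the exclusive (reverse) KL divergence (Chan et al., 2022), we can expect the optimizer $\widehat{\mathbb{P}}$ to be a local estimator of the nonlinear solution in (7a).

Additionally, the loss function (8b) to learn the variational score $\mathbf{A}_t$, where $t \in [0, T]$, can be simplified to

$$
\overrightarrow{\mathcal{L}}(\mathbf{A}) = -\int_0^T \mathbb{E}_{x_t \backsim \text{(9b)}}\left[\Gamma_\zeta(\mathbf{A}_t x_t; \overleftarrow{z}^\theta_t)\,dt \,\Big|\, \overleftarrow{x}_T = x_T\right], \tag{10}
$$

where $\Gamma_\zeta$ is defined in Eq.(6). Since the structure property $\overrightarrow{\psi}_t \overleftarrow{\varphi}_t = \overrightarrow{\rho}_t$ in Eq.(5) is compromised by the variational inference, we propose to tune $\zeta$ in our experiments.

4.2. Closed-form Expression of Backward Score

Assuming prior knowledge of $\mathbf{A}_t$ is given, we can rewrite the forward process (9a) in the VFB-SDE and derive a multivariate forward diffusion (Singhal et al., 2023):

$$
d\overrightarrow{x}_t = \left[-\frac{1}{2}\beta_t \mathbf{I} + \beta_t \mathbf{A}_t\right]\overrightarrow{x}_t\,dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t = -\frac{1}{2}\mathbf{D}_t \beta_t \overrightarrow{x}_t\,dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, \tag{11}
$$

where $\mathbf{D}_t = \mathbf{I} - 2\mathbf{A}_t \in \mathbb{R}^{d \times d}$ is a positive-definite matrix†.

† $\mathbf{D}_t = -2\mathbf{A}_t \in \mathbb{R}^{d \times d}$ when the forward SDE is VE-SDE.

Consider the multivariate OU process (11). The mean and covariance follow

$$
\begin{aligned}
\frac{d\boldsymbol{\mu}_{t|0}}{dt} &= -\frac{1}{2}\beta_t \mathbf{D}_t \boldsymbol{\mu}_{t|0}, && \text{(12a)} \\
\frac{d\boldsymbol{\Sigma}_{t|0}}{dt} &= -\frac{1}{2}\beta_t \left(\mathbf{D}_t \boldsymbol{\Sigma}_{t|0} + \boldsymbol{\Sigma}_{t|0} \mathbf{D}_t^\intercal\right) + \beta_t \mathbf{I}. && \text{(12b)}
\end{aligned}
$$

Solving the differential equations with the help of integration factors, the mean process follows

$$
\boldsymbol{\mu}_{t|0} = e^{-\frac{1}{2}[\beta \mathbf{D}]_t} x_0, \tag{13}
$$

where $[\beta \mathbf{D}]_t = \int_0^t \beta_s \mathbf{D}_s\,ds$. By matrix decomposition $\boldsymbol{\Sigma}_{t|0} = \mathbf{C}_t \mathbf{H}_t^{-1}$ (Särkkä & Solin, 2019), the covariance process follows that:

$$
\begin{bmatrix} \mathbf{C}_t \\ \mathbf{H}_t \end{bmatrix} = \exp\left(\begin{bmatrix} -\frac{1}{2}[\beta \mathbf{D}]_t & [\beta \mathbf{I}]_t \\ \mathbf{0} & \frac{1}{2}[\beta \mathbf{D}^\intercal]_t \end{bmatrix}\right) \begin{bmatrix} \boldsymbol{\Sigma}_0 \\ \mathbf{I} \end{bmatrix}, \tag{14}
$$

where the above matrix exponential can be easily computed through modern computing libraries. Further, to avoid computing the expensive matrix exponential for high-dimensional problems, we can adopt a diagonal and time-invariant $\mathbf{D}_t$.

Suppose $\boldsymbol{\Sigma}_{t|0}$ has the Cholesky decomposition $\boldsymbol{\Sigma}_{t|0} = \mathbf{L}_t \mathbf{L}_t^\intercal$ for some lower-triangular matrix $\mathbf{L}_t$. We can have a closed-form update that resembles the SGM:

$$
\overrightarrow{x}_t = \boldsymbol{\mu}_{t|0} + \mathbf{L}_t \epsilon,
$$

where $\boldsymbol{\mu}_{t|0}$ is defined in Eq.(13) and $\epsilon$ is the standard $d$-dimensional Gaussian vector. The score function follows

$$
\begin{aligned}
\nabla \log \overrightarrow{\rho}_{t|0}(\overrightarrow{x}_t) &= -\frac{1}{2}\nabla\left[(\overrightarrow{x}_t - \boldsymbol{\mu}_t)^\intercal \boldsymbol{\Sigma}_{t|0}^{-1} (\overrightarrow{x}_t - \boldsymbol{\mu}_t)\right] \\
&= -\boldsymbol{\Sigma}_{t|0}^{-1}(\overrightarrow{x}_t - \boldsymbol{\mu}_t) \\
&= -\mathbf{L}_t^{-\intercal} \mathbf{L}_t^{-1} \mathbf{L}_t \epsilon := -\mathbf{L}_t^{-\intercal} \epsilon.
\end{aligned} \tag{15}
$$

Invoking the ESM loss function in Eq.(2), we can learn the score function $\nabla \log \overrightarrow{\rho}_{t|0}(\overrightarrow{x}_t \mid \overrightarrow{x}_0)$ using a neural network parametrization $s_t(\cdot)$ and optimize the loss function:

$$
\nabla_{\mathbf{A}} \|\mathbf{L}_t^{-\intercal} \epsilon - s_t(x_t)\|_2^2. \tag{16}
$$

One may further consider preconditioning techniques (Karras et al., 2022) or variance reduction (Singhal et al., 2023) to stabilize training and accelerate training speed.

Speed-ups via time-invariant and diagonal $\mathbf{D}_t$. If we parametrize $\mathbf{D}_t$ as a time-invariant and diagonal positive-definite matrix, the formula (14) has simpler explicit expressions that do not require calling matrix exponential operators. We present such a result in Corollary 1. For the image generation experiment in Section 7.3, we use such a diagonal parametrization when implementing the VSDM.

Corollary 1. Suppose $\mathbf{D}_t = \boldsymbol{\Lambda} := \mathrm{diag}(\boldsymbol{\lambda})$, where $\lambda_i \ge 0$ for all $1 \le i \le d$, and denote $\sigma_t^2 := \int_0^t \beta_s\,ds$. Then the matrices $\mathbf{C}_t$ and $\mathbf{H}_t$ have simpler expressions with

$$
\mathbf{C}_t = \boldsymbol{\Lambda}^{-1}\left\{\exp\left(\tfrac{1}{2}\sigma_t^2 \boldsymbol{\Lambda}\right) - \exp\left(-\tfrac{1}{2}\sigma_t^2 \boldsymbol{\Lambda}\right)\right\}, \qquad \mathbf{H}_t = \exp\left(\tfrac{1}{2}\sigma_t^2 \boldsymbol{\Lambda}\right),
$$

which leads to $\mathbf{C}_t \mathbf{H}_t^{-1} = \boldsymbol{\Lambda}^{-1}\left\{\mathbf{I} - \exp(-\sigma_t^2 \boldsymbol{\Lambda})\right\}$. As a result, the corresponding forward transition writes

$$
\boldsymbol{\mu}_{t|0} = \exp\left(-\tfrac{1}{2}\sigma_t^2 \boldsymbol{\Lambda}\right) x_0, \qquad \mathbf{L}_t = \boldsymbol{\Lambda}^{-\frac{1}{2}} \sqrt{\mathbf{I} - \exp(-\sigma_t^2 \boldsymbol{\Lambda})}.
$$

In Corollary 1, detailed in Appendix A, since the matrix $\boldsymbol{\Lambda} = \mathrm{diag}(\boldsymbol{\lambda})$ is diagonal and time-invariant, the matrix exponential and square root can be directly calculated element-wise on each diagonal element $\lambda_i$ independently.

4.2.1. BACKWARD SDE

Taking the time reversal (Anderson, 1982) of the forward multivariate OU process (11), the backward SDE satisfies

$$
d\overleftarrow{x}_t = \left(-\frac{1}{2}\mathbf{D}_t \beta_t \overleftarrow{x}_t - \beta_t s_t(\overleftarrow{x}_t)\right)dt + \sqrt{\beta_t}\,d\overleftarrow{w}_t. \tag{17}
$$

Notably, with a general PD matrix $\mathbf{D}_t$, the prior distribution follows that $x_T \sim \mathrm{N}(\mathbf{0}, \boldsymbol{\Sigma}_{T|0})$‡. We also note that the prior is now limited to Gaussian distributions, which is not a general bridge anymore.

‡ See the Remark on the selection of $\rho_{\text{prior}}$ in section B.1.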
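The diagonal forward transition of Corollary 1, the score target of Eq.(15), and an Euler–Maruyama discretization of the backward SDE (17) can be sketched together in a few lines of NumPy. This is only an illustrative sketch under stated assumptions: β is constant, every λ_i is strictly positive (so the division by λ is well defined), and the learned score network is replaced by the exact score of a zero-mean Gaussian toy dataset. The helper names (`forward_transition`, `sample_xt_and_target`, `backward_em`, `gaussian_score`) are hypothetical.

```python
import numpy as np


def forward_transition(lam, sigma2):
    """Diagonal quantities of Corollary 1 for D_t = diag(lam), lam > 0 (assumed).

    Returns the diagonal of exp(-0.5*sigma2*Lambda) and the diagonal of the
    Cholesky factor L_t = Lambda^{-1/2} sqrt(I - exp(-sigma2*Lambda)).
    """
    decay = np.exp(-0.5 * sigma2 * lam)
    L_diag = np.sqrt(-np.expm1(-sigma2 * lam) / lam)
    return decay, L_diag


def sample_xt_and_target(x0, lam, sigma2, rng):
    """Draw x_t = mu_{t|0} + L_t eps and the conditional score -L_t^{-T} eps."""
    decay, L_diag = forward_transition(lam, sigma2)
    eps = rng.standard_normal(x0.shape)
    x_t = decay * x0 + L_diag * eps
    return x_t, -eps / L_diag  # L_t is diagonal, so L^{-T} eps = eps / L_diag


def backward_em(score, lam, beta, T, n_steps, n_samples, rng):
    """Euler-Maruyama for Eq.(17), stepping from t = T down to 0 (constant beta)."""
    h = T / n_steps
    _, L_T = forward_transition(lam, beta * T)
    x = L_T * rng.standard_normal((n_samples, lam.size))  # x_T ~ N(0, Sigma_{T|0})
    for n in range(n_steps, 0, -1):
        t = n * h
        drift = -0.5 * beta * lam * x - beta * score(x, t)
        x = x - h * drift + np.sqrt(beta * h) * rng.standard_normal(x.shape)
    return x


# Toy setting (assumption): data ~ N(0, diag(c)), so the marginal score is exact.
lam = np.array([2.0, 0.5])
c = np.array([0.25, 4.0])
beta, T = 4.0, 2.0


def gaussian_score(x, t):
    decay, L_diag = forward_transition(lam, beta * t)
    return -x / (decay**2 * c + L_diag**2)


rng = np.random.default_rng(0)
samples = backward_em(gaussian_score, lam, beta, T, n_steps=400, n_samples=20000, rng=rng)
```

In this toy setting the backward samples should roughly recover the anisotropic data variances in `c`; in practice `gaussian_score` would be replaced by the trained estimator s_t.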
4.2.2. PROBABILITY FLOW ODE

We can follow Song et al. (2021b) and obtain the deterministic process directly:

$$
d\overleftarrow{x}_t = \left(-\frac{1}{2}\mathbf{D}_t \beta_t \overleftarrow{x}_t - \frac{1}{2}\beta_t s_t(\overleftarrow{x}_t)\right)dt, \tag{18}
$$

where $x_T \sim \mathrm{N}(\mathbf{0}, \boldsymbol{\Sigma}_{T|0})$ and the sample trajectories follow the same marginal densities $\overrightarrow{\rho}_t(x_t)$ as in the SDE.

4.3. Adaptive Diffusion via Stochastic Approximation

Our major goal is to generate high-fidelity data with efficient transportation plans based on the optimal $\mathbf{A}^\star_t$ in the forward process (11). However, the optimal $\mathbf{A}^\star_t$ is not known a priori. To tackle this issue, we leverage stochastic approximation (SA) (Robbins & Monro, 1951; Benveniste et al., 1990) to adaptively optimize the variational score $\mathbf{A}^{(k)}_t$ through optimal transport and simulate the backward trajectories.

(1) Simulate backward trajectories $\{\overleftarrow{x}^{(k+1)}_{nh}\}_{n=0}^{N-1}$ via the Euler–Maruyama (EM) scheme of the backward process (17) with a learning rate $h$.

(2) Optimize variational scores $\{\mathbf{A}^{(k)}_{nh}\}_{n=0}^{N-1}$:

$$
\mathbf{A}^{(k+1)}_{nh} = \mathbf{A}^{(k)}_{nh} - \eta_{k+1} \nabla \overrightarrow{\mathcal{L}}_{nh}(\mathbf{A}^{(k)}_{nh}; \overleftarrow{x}^{(k+1)}_{nh}),
$$

where $\nabla \overrightarrow{\mathcal{L}}_{nh}(\mathbf{A}^{(k)}_{nh}; \overleftarrow{x}^{(k+1)}_{nh})$ is the gradient of the loss function (10) at time $nh$ and is known as the random field. We expect that the simulation of backward trajectories $\{\overleftarrow{x}^{(k+1)}_{nh}\}_{n=0}^{N-1}$ given $s^{(k+1)}_{nh}$ helps the optimization of $\mathbf{A}^{(k+1)}_{nh}$ and the optimized $\mathbf{A}^{(k+1)}_{nh}$ in turn contributes to a more efficient transportation plan for estimating $s^{(k+2)}_{nh}$ and simulating the backward trajectories $\{\overleftarrow{x}^{(k+2)}_{nh}\}_{n=0}^{N-1}$.

Trajectory Averaging. The stochastic approximation algorithm is a standard framework to study adaptive sampling algorithms (Liang et al., 2007). Moreover, the formulation suggests stabilizing the trajectories (Polyak & Juditsky, 1992) with averaged parameters $\overline{\mathbf{A}}^{(k)}_{nh}$ as follows

$$
\overline{\mathbf{A}}^{(k)}_{nh} = \frac{1}{k}\sum_{i=1}^{k} \mathbf{A}^{(i)}_{nh} = \left(1 - \frac{1}{k}\right)\overline{\mathbf{A}}^{(k-1)}_{nh} + \frac{1}{k}\mathbf{A}^{(k)}_{nh},
$$

where $\overline{\mathbf{A}}^{(k)}_{nh}$ is known to be an asymptotically efficient (optimal) estimator (Polyak & Juditsky, 1992) in the local state space $\mathcal{A}$ by assumption A1.

Exponential Moving Average (EMA). Despite guarantees in convex scenarios, the parameter space differs tremendously in different surfaces in non-convex state space $\mathcal{A}$. Empirically, if we want to exploit information from multiple modes, a standard extension is to employ the EMA technique (Trivedi & Kondor, 2017):

$$
\overline{\mathbf{A}}^{(k)}_{nh} = (1 - \eta)\overline{\mathbf{A}}^{(k-1)}_{nh} + \eta \mathbf{A}^{(k)}_{nh}, \quad \text{where } \eta \in (0, 1).
$$

The EMA techniques are widely used empirically in diffusion models and Schrödinger bridge (Song & Ermon, 2020; De Bortoli et al., 2021; Chen et al., 2022b) to avoid oscillating trajectories. Now we are ready to present our methodology in Algorithm 1.

Computational Cost. Regarding the wall-clock computational time: i) training (linear) variational scores, albeit in a simulation-based manner, becomes significantly faster than estimating nonlinear forward scores in Schrödinger bridge; ii) the variational parametrization greatly reduces the number of model parameters, which yields a much-reduced variance in Hutchinson's estimator (Hutchinson, 1989); iii) since we don't need to update $\mathbf{A}_t$ as often as the backward score model, we can further amortize the training of $\mathbf{A}_t$. In the simulation example in Figure 9(b), VSDM is only 10% slower than the SGM with the same training complexity of backward scores while still maintaining efficient convergence of variational scores.

5. Convergence of Stochastic Approximation

In this section, we study the convergence of $\mathbf{A}^{(k)}_t$ to the optimal $\mathbf{A}^\star_t$, where $t \in [0, T]$§. The primary objective is to show the iterates (19) follow the trajectories of the dynamical system asymptotically:

$$
d\mathbf{A}_t = \nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}_t)\,ds, \tag{20}
$$

where $\frac{d\mathbf{A}_t}{ds} = \lim_{\eta \to 0} \frac{\mathbf{A}^{(k+1)}_t - \mathbf{A}^{(k)}_t}{\eta}$ and $\nabla \overrightarrow{\mathcal{L}}_t(\cdot)$ is the mean field at time $t$:

$$
\nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}_t) = \int_{\mathcal{X}} \nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}_t; \overleftarrow{x}^{(\cdot)}_t)\,\overleftarrow{\rho}_t(d\overleftarrow{x}^{(\cdot)}_t), \tag{21}
$$

where $\mathcal{X}$ denotes the state space of data $x$ and $\nabla \overrightarrow{\mathcal{L}}_t$ denotes the gradient w.r.t. $\mathbf{A}_t$; $\overleftarrow{\rho}_t$ is the distribution of the continuous-time interpolation of the discretized backward SDE (22) from $t = T$ to $0$. We denote by $\mathbf{A}^\star_t$ one of the solutions of $\nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}^\star_t) = 0$.

§ We slightly abuse the notation and generalize $\mathbf{A}^{(k)}_{nh}$ to $\mathbf{A}^{(k)}_t$.

The aim is to find the optimal solution $\mathbf{A}^\star_t$ to the mean field $\nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}^\star_t) = 0$. However, we acknowledge that the equilibrium is not unique in general nonlinear dynamical systems. To tackle this issue, we focus our analysis around a neighborhood $\Theta$ of the equilibrium by assumption A1. After running sufficiently many iterations with a small enough
Algorithm 1 Variational Schrödinger Diffusion Models (VSDM). $\rho_{\text{prior}}$ is fixed to a Gaussian distribution. $\eta_k$ is the step size for SA and $h$ is the learning rate for the backward sampling of Eq.(17). $\xi_n$ denotes the standard Gaussian vector at the sampling iteration $n$. The exponential moving averaging (EMA) technique can be used to further stabilize the algorithm.

repeat
    Simulation-free Optimization of Backward Score
    Draw $x_0 \sim \rho_{\text{data}}$, $n \sim \{0, 1, \cdots, N-1\}$, $\epsilon \sim \mathrm{N}(\mathbf{0}, \mathbf{I})$.
    Sample $x_{nh} \mid x_0 \sim \mathrm{N}(\boldsymbol{\mu}_{nh|0}, \boldsymbol{\Sigma}_{nh|0})$ by Eq.(13) and (14) given $\mathbf{A}^{(k)}_{nh}$.
    Cache $\{\boldsymbol{\mu}_{nh|0}\}_{n=0}^{N-1}$ and $\{\mathbf{L}^{-\intercal}_{nh}\}_{n=0}^{N-1}$ via Cholesky decomposition of $\{\boldsymbol{\Sigma}_{nh}\}_{n=0}^{N-1}$ to avoid repeated computations.
    Optimize the score functions $s^{(k+1)}_{nh}$ sufficiently through the loss function $\nabla_\theta \|\mathbf{L}^{-\intercal}_{nh} \epsilon - s^{(k+1)}_{nh}(x_{nh})\|_2^2$.
    Optimization of Variational Score via Stochastic Approximation (SA)
    Simulate the backward trajectory $\overleftarrow{x}^{(k+1)}_{nh}$ given $\mathbf{A}^{(k)}_{nh}$ via Eq.(22), where $\overleftarrow{x}^{(k+1)}_{(N-1)} \sim \mathrm{N}(\mathbf{0}, \boldsymbol{\Sigma}^{(k)}_{(N-1)h|0})$.
    Optimize variational score $\mathbf{A}^{(k+1)}_{nh}$ using the loss function (10), where $n \in \{0, 1, \cdots, N-1\}$:

    $$
    \mathbf{A}^{(k+1)}_{nh} = \mathbf{A}^{(k)}_{nh} - \eta_{k+1} \nabla \overrightarrow{\mathcal{L}}_{nh}(\mathbf{A}^{(k)}_{nh}; \overleftarrow{x}^{(k+1)}_{nh}). \tag{19}
    $$
step size $\eta_k$, suppose $\mathbf{A}^{(k)}_t \in \Theta$ is somewhere near one equilibrium $\mathbf{A}^\star_t$ (out of all equilibria), then by the induction method, the iteration tends to get trapped in the same region as shown in Eq.(32) and yields the convergence to one equilibrium $\mathbf{A}^\star_t$. We also present the variational gap of the (sub)-optimal transport and show our transport is more efficient than diffusion models with Gaussian marginals.

Next, we outline informal assumptions and sketch our main results, reserving formal ones for readers interested in the details in the appendix. We also formulate the optimization of the variational score $\mathbf{A}_t$ using stochastic approximation in Algorithm 2 in the supplementary material.

Assumption A1 (Regularity). (Positive definiteness) For any $t \ge 0$ and $\mathbf{A}_t \in \mathcal{A}$, $\mathbf{D}_t = \mathbf{I} - 2\mathbf{A}_t$ is positive definite. (Locally strong convexity) For any stable local minimum $\mathbf{A}^\star_t$ with $\nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}^\star_t) = 0$, there is always a neighborhood $\Theta$ s.t. $\mathbf{A}^\star_t \in \Theta \subset \mathcal{A}$ and $\overrightarrow{\mathcal{L}}_t$ is strongly convex in $\Theta$.

By the mode-seeking property of the exclusive (reverse) KL divergence (Chan et al., 2022), we only make a mild assumption on a small neighborhood of the solution and expect the convergence given proper regularities.

Assumption A2 (Lipschitz Score). For any $t \in [0, T]$, the score $\nabla \log \overrightarrow{\rho}_t$ is $L$-Lipschitz.

Assumption A3 (Second Moment Bound). The data distribution has a bounded second moment.

Assumption A4 (Score Estimation Error). We have bounded score estimation errors in $L^2$ quantified by $\epsilon_{\text{score}}$.

We first use the multivariate diffusion to train our score estimators $\{s^{(k)}_t\}_{n=0}^{N-1}$ via the loss function (16) based on the pre-specified $\mathbf{A}^{(k)}_t$ at step $k$. Similar in spirit to Chen et al. (2023a; 2022a), we can show the generated samples based on $\{s^{(k)}_t\}_{n=0}^{N-1}$ are close in distribution to the ideal samples in Theorem 1. The novelty lies in the extension of single-variate diffusions to multi-variate diffusions.

Theorem 1 (Generation quality, informal). Assume assumptions A1-A4 hold with a fixed $\mathbf{A}^{(k)}_t$. Then the generated data distribution is close to the data distribution $\rho_{\text{data}}$ such that

$$
\mathrm{TV}(\overleftarrow{\rho}^{(k)}_0, \rho_{\text{data}}) \lesssim \exp(-T) + \left(\sqrt{dh} + \epsilon_{\text{score}}\right)\sqrt{T}.
$$

To show the convergence of $\mathbf{A}^{(k)}_t$ to $\mathbf{A}^\star_t$, the proof hinges on a stability condition such that the solution asymptotically tracks the equilibrium $\mathbf{A}^\star_t$ of the mean field (20).

Lemma 2 (Local stability, informal). Assume the assumptions A1 and A2 hold. For any $t \in [0, T]$ and any $\mathbf{A} \in \Theta$, the solution satisfies a local stability condition such that

$$
\langle \mathbf{A} - \mathbf{A}^\star_t, \nabla \overrightarrow{\mathcal{L}}_t(\mathbf{A}) \rangle \gtrsim \|\mathbf{A} - \mathbf{A}^\star_t\|_2^2.
$$

The preceding result illustrates the convergence of the solution toward the equilibrium on average. The next assumption imposes a slow update of the SA process, which is standard for theoretical analysis but may not always be needed in empirical evaluations.

Assumption A5 (Step size). The step size $\{\eta_k\}_{k \in \mathbb{N}}$ is a positive and decreasing sequence

$$
\eta_k \to 0, \qquad \sum_{k=1}^{\infty} \eta_k = +\infty, \qquad \sum_{k=1}^{\infty} \eta_k^2 < +\infty.
$$
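The step-size conditions of Assumption A5, the SA update (19), and the two averaging schemes of Section 4.3 can be sketched on a toy problem. The example below is an illustration under stated assumptions, not the paper's training loop: the random field is the gradient of a quadratic with a known minimizer `A_star` plus Gaussian noise, standing in for the gradient of the loss (10), and the schedule η_k = c / (k + k_0)^α with α in (1/2, 1] is one common choice that satisfies Assumption A5. All names and constants are hypothetical.

```python
import numpy as np


def step_size(k, c=1.0, k0=10.0, alpha=0.75):
    """eta_k = c / (k + k0)^alpha: positive, decreasing, sum = inf, sum of squares < inf."""
    return c / (k + k0) ** alpha


def run_sa(A_star, n_iters=5000, noise=0.5, ema_rate=0.01, seed=0):
    """Toy stochastic approximation with Polyak averaging and EMA.

    The noisy gradient (A - A_star) + noise is an assumed stand-in for the
    random field of Eq.(19); it is strongly convex around A_star.
    """
    rng = np.random.default_rng(seed)
    A = np.zeros_like(A_star)
    A_avg = np.zeros_like(A_star)  # trajectory (Polyak) average
    A_ema = np.zeros_like(A_star)  # exponential moving average
    for k in range(1, n_iters + 1):
        grad = (A - A_star) + noise * rng.standard_normal(A.shape)
        A = A - step_size(k) * grad                        # SA update
        A_avg = (1.0 - 1.0 / k) * A_avg + A / k            # running mean of iterates
        A_ema = (1.0 - ema_rate) * A_ema + ema_rate * A    # EMA of iterates
    return A, A_avg, A_ema


# Hypothetical target: D = I - 2 * A_star is positive definite.
A_star = np.array([[-0.5, 0.1], [0.1, 0.25]])
A_last, A_avg, A_ema = run_sa(A_star)
```

On this toy problem all three estimates settle near `A_star`; the averaged iterates are typically less noisy than the last iterate, which is the motivation for averaging given above.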
Next, we use the stochastic approximation theory to prove the convergence of $\mathbf{A}^{(k)}_t$ to an equilibrium $\mathbf{A}^\star_t$.

Theorem 2 (Convergence in $L^2$). Assume assumptions A1-A5 hold. The variational score $\mathbf{A}^{(k)}_t$ converges to an equilibrium $\mathbf{A}^\star_t$ in $L^2$ such that

$$
\mathbb{E}\left[\|\mathbf{A}^{(k)}_t - \mathbf{A}^\star_t\|_2^2\right] \le 2\eta_k,
$$

where the expectation is taken w.r.t. samples from $\overleftarrow{\rho}^{(k)}_t$.

In the end, we adapt Theorem 1 again to show the adaptively generated samples are asymptotically close to the samples based on the optimal $\mathbf{A}^\star_t$ in Theorem 3, which quantifies the quality of data based on more efficient transportation plans.

Theorem 3 (Generation quality of adaptive samples). Given assumptions A1-A5, the generated sample distribution at stage $k$ is close to the exact sample distribution based on the equilibrium $\mathbf{A}^\star_t$ such that

$$
\mathrm{TV}(\overleftarrow{\rho}^\star_0, \rho_{\text{data}}) \lesssim \exp(-T) + \left(\sqrt{dh} + \epsilon_{\text{score}} + \sqrt{\eta_k}\right)\sqrt{T}.
$$

6. Variational Gap

Recall that the optimal and variational forward SDEs follow

$$
\begin{aligned}
d\overrightarrow{x}_t &= \left[f_t(\overrightarrow{x}_t) + \beta_t \nabla \log \overrightarrow{\psi}_t(\overrightarrow{x}_t)\right]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, \\
d\overrightarrow{x}_t &= \left[f_t(\overrightarrow{x}_t) + \beta_t \mathbf{A}^{(k)}_t \overrightarrow{x}_t\right]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t, \\
d\overrightarrow{x}_t &= \left[f_t(\overrightarrow{x}_t) + \beta_t \mathbf{A}^\star_t \overrightarrow{x}_t\right]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t,
\end{aligned}
$$

where we abuse the notation $\overrightarrow{x}_t$ for the sake of clarity and they represent three different processes. Despite the improved efficiency based on the ideal $\mathbf{A}^\star_t$ compared to the vanilla $\mathbf{A}_t \equiv 0$, the variational score inevitably yields a sub-optimal transport in general nonlinear transport. We denote the law of the above processes by $\mathrm{L}$, $\mathrm{L}^{(k)}$, and $\mathrm{L}^\star$. To assess the disparity, we leverage the Girsanov theorem to study the variational gap.

Theorem 3 (Variational gap). Assume the assumption A2 and Novikov's condition hold. Assume $f_t$ and $\nabla \log \overrightarrow{\psi}_t$ are Lipschitz smooth and satisfy the linear growth. The variational gap follows that

$$
\begin{aligned}
\mathrm{KL}(\mathrm{L} \| \mathrm{L}^\star) &= \frac{1}{2}\int_0^T \mathbb{E}\left[\beta_t \|\mathbf{A}^\star_t \overrightarrow{x}_t - \nabla \log \overrightarrow{\psi}_t(\overrightarrow{x}_t)\|_2^2\right]dt, \\
\mathrm{KL}(\mathrm{L} \| \mathrm{L}^{(k)}) &\lesssim \eta_k + \mathrm{KL}(\mathrm{L} \| \mathrm{L}^\star).
\end{aligned}
$$

Connections to Gaussian Schrödinger bridge (GSB). When data follows a Gaussian distribution, VSDM approximates the closed-form OT solution of Schrödinger bridge (Janati et al., 2020; Bunne et al., 2023). We refer readers to Theorem 3 (Bunne et al., 2023) for the detailed transportation plans. Compared to the vanilla $\mathbf{A}_t \equiv 0$, we can significantly reduce the variational gap with $\mathrm{KL}(\mathrm{L} \| \mathrm{L}^\star)$ using proper parametrization and sufficient training.

7. Empirical Studies

7.1. Comparison to Gaussian Schrödinger Bridge

VSDM is approximating GSB (Bunne et al., 2023) when both marginals are Gaussian distributions. To evaluate the solutions, we run our VSDM with a fixed $\beta_t \equiv 4$ in Eq.(25) in Song et al. (2021b) and use the same marginals to replicate the VPSDE of the Gaussian SB with $\alpha_t \equiv 0$ and $c_t \equiv -2$ in Eq.(7) in Bunne et al. (2023). We train VSDM with 20 stages and randomly pick 256 samples for presentation. We compare the flow trajectories from both models and observe in Figure 1 that the ground truth solution forms an almost linear path, while our VSDM sample trajectories exhibit a consistent alignment with trajectories from Gaussian SB. We attribute the bias predominantly to score estimations and numerical discretization.

(a) GSB  (b) VSDM
Figure 1. Gaussian vs. VSDM on the flow trajectories.

7.2. Synthetic Data

We test our variational Schrödinger diffusion models (VSDMs) on two synthetic datasets: spiral and checkerboard (detailed in section D.2.1). We include SGMs as the baseline models and aim to show the strength of VSDMs on general shapes with straighter trajectories. As such, we stretch the Y-axis of the spiral data by 8 times and the X-axis of the checkerboard data by 6 times and denote them by spiral-8Y and checkerboard-6X, respectively.

We adopt a monotone increasing $\{\beta_{nh}\}_{n=0}^{N-1}$ similar to Song et al. (2021b) and denote by $\beta_{\min}$ and $\beta_{\max}$ the minimum and maximum of $\{\beta_{nh}\}_{n=0}^{N-1}$. We fix $\zeta = 0.75$ and $\beta_{\min} = 0.1$ and we focus on the study with different $\beta_{\max}$. We find that SGMs work pretty well with $\beta_{\max} = 10$ (SGM-10) on standard isotropic shapes. However, when it comes to spiral-8Y, the SGM-10 struggles to recover the boundary regions on the spiral-8Y data as shown in Figure 2 (top).

Generations of Anisotropic Shapes. To illustrate the effectiveness of our approach, Figure 2 (bottom) shows that VSDM-10 accurately reconstructs the edges of the spiral
Figure 2. Variational Schrödinger diffusion models (VSDMs, bottom) vs. SGMs (top) with the same hyperparameters ($\beta_{\max} = 10$).

Figure 4. Unconditional generated samples from VSDM on CIFAR10 (32×32 resolution) trained from scratch.
FID↓ (NFE=35) 406.13 13.13 8.65 6.83 5.66 5.21 3.62 3.29 3.01 2.28
Table 2. CIFAR10 evaluation using sample quality (FID score). Our VSDM outperforms other optimal transport baselines by a large margin.

Class    Method         FID ↓
         VSDM (ours)    2.28

Table 3. Forecasting results (lower is better).

CRPS-sum       Electricity    Exchange rate    Solar
DDPM           0.026±0.007    0.012±0.001      0.506±0.058
SGM            0.045±0.005    0.012±0.002      0.413±0.045
VSDM (ours)    0.038±0.006    0.008±0.002      0.395±0.011
Huang, C.-W., Lim, J. H., and Courville, A. A Variational Perspective on Diffusion-Based Generative Models and Score Matching. In Advances in Neural Information Processing Systems (NeurIPS), 2021.
Hutchinson, M. F. A Stochastic Estimator of the Trace of the Influence Matrix for Laplacian Smoothing Splines. Communications in Statistics-Simulation and Computation, 18(3):1059–1076, 1989.
Hyvärinen, A. Estimation of Non-normalized Statistical Models by Score Matching. Journal of Machine Learning Research, 6(24):695–709, 2005.
Janati, H., Muzellec, B., Peyré, G., and Cuturi, M. Entropic Optimal Transport between Unbalanced Gaussian Measures has a Closed Form. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
Karatzas, I. and Shreve, S. E. Brownian Motion and Stochastic Calculus. Springer, 1998.
Karras, T., Aittala, M., Aila, T., and Laine, S. Elucidating the Design Space of Diffusion-Based Generative Models. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
Kingma, D. P., Salimans, T., Poole, B., and Ho, J. Variational Diffusion Models. ArXiv, abs/2107.00630, 2021.
Koehler, F., Heckett, A., and Risteski, A. Statistical Efficiency of Score Matching: The View from Isoperimetry. In ICLR, 2023.
Kong, Z., Ping, W., Huang, J., Zhao, K., and Catanzaro, B. DiffWave: A Versatile Diffusion Model for Audio Synthesis. In Proc. of the International Conference on Learning Representation (ICLR), 2021.
Kullback, S. Probability Densities with Given Marginals. Ann. Math. Statist., 1968.
Lavenant, H. and Santambrogio, F. The Flow Map of the Fokker–Planck Equation Does Not Provide Optimal Transport. Applied Mathematics Letters, 133, 2022.
Lee, H., Lu, J., and Tan, Y. Convergence for Score-based Generative Modeling with Polynomial Complexity. Advances in Neural Information Processing Systems (NeurIPS), 2022.
Léonard, C. A Survey of the Schrödinger Problem and Some of its Connections with Optimal Transport. Discrete & Continuous Dynamical Systems-A, 34(4):1533–1574, 2014.
Liang, F., Liu, C., and Carroll, R. J. Stochastic Approximation in Monte Carlo Computation. Journal of the American Statistical Association, 102:305–320, 2007.
Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., and Le, M. Flow Matching for Generative Modeling. In Proc. of the International Conference on Learning Representation (ICLR), 2023.
Liptser, R. S. and Shiryaev, A. N. Statistics of Random Processes: I. General Theory. Springer Science & Business Media, 2001.
Liu, Q. Rectified Flow: A Marginal Preserving Approach to Optimal Transport. arXiv:2209.14577, 2022.
Liu, X., Gong, C., and Liu, Q. Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. In International Conference on Learning Representation (ICLR), 2023.
Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., and Zhu, J. DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
Luo, W. A Comprehensive Survey on Knowledge Distillation of Diffusion Models. arXiv preprint arXiv:2304.04262, 2023.
Luo, W., Hu, T., Zhang, S., Sun, J., Li, Z., and Zhang, Z. Diff-instruct: A Universal Approach for Transferring Knowledge from Pre-trained Diffusion Models. Advances in Neural Information Processing Systems, 36, 2024a.
Luo, W., Zhang, B., and Zhang, Z. Entropy-based Training Methods for Scalable Neural Implicit Samplers. NeurIPS, 36, 2024b.
Ma, J. and Yong, J. Forward-Backward Stochastic Differential Equations and their Applications. Springer, 2007.
Marzouk, Y., Moselhy, T., Parno, M., and Spantini, A. Sampling via Measure Transport: An Introduction. Handbook of Uncertainty Quantification, pp. 1–41, 2016.
McCann, R. J. A Convexity Principle for Interacting Gases. Advances in Mathematics, 128(1):153–179, 1997.
Øksendal, B. Stochastic Differential Equations: An Introduction with Applications. Springer, 2003.
Onken, D., Fung, S. W., Li, X., and Ruthotto, L. OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport. In Proc. of the National Conference on Artificial Intelligence (AAAI), 2021.
Pavon, M., Tabak, E. G., and Trigila, G. The Data-driven Schrödinger Bridge. Communications on Pure and Applied Mathematics, 74:1545–1573, 2021.
Peluchetti, S. Diffusion Bridge Mixture Transports, Schrödinger Bridge Problems and Generative Modeling. ArXiv e-prints arXiv:2304.00917v1, 2023.
Peyré, G. and Cuturi, M. Computational Optimal Transport: With Applications to Data Science. Foundations and Trends in Machine Learning, 2019.
Polyak, B. T. and Juditsky, A. Acceleration of Stochastic Approximation by Averaging. SIAM Journal on Control and Optimization, 30:838–855, 1992.
Pooladian, A.-A., Ben-Hamu, H., Domingo-Enrich, C., Amos, B., Lipman, Y., and Chen, R. T. Q. Multisample Flow Matching: Straightening Flows with Minibatch Couplings. In ICML, 2023.
Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., and Chen, M. Hierarchical Text-Conditional Image Generation with CLIP Latents. In arXiv:2204.06125v1, 2022.
Rasul, K., Seward, C., Schuster, I., and Vollgraf, R. Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting. In International Conference on Machine Learning, 2021.
Robbins, H. and Monro, S. A Stochastic Approximation Method. Annals of Mathematical Statistics, 22:400–407, 1951.
Ruschendorf, L. Convergence of the Iterative Proportional Fitting Procedure. Ann. of Statistics, 1995.
Salimans, T. and Ho, J. Progressive Distillation for Fast Sampling of Diffusion Models. In ICLR, 2022.
Salinas, D., Bohlke-Schneider, M., Callot, L., Medico, R., and Gasthaus, J. High-dimensional Multivariate Forecasting with Low-rank Gaussian Copula Processes. Advances in Neural Information Processing Systems, 2019.
Särkkä, S. and Solin, A. Applied Stochastic Differential Equations. Cambridge University Press, 2019.
Shi, Y., De Bortoli, V., Campbell, A., and Doucet, A. Diffusion Schrödinger Bridge Matching. In Advances in Neural Information Processing Systems (NeurIPS), 2023.
Singhal, R., Goldstein, M., and Ranganath, R. Where to Diffuse, How to Diffuse, and How to Get Back: Automated Learning for Multivariate Diffusions. In Proc. of the International Conference on Learning Representation (ICLR), 2023.
Song, Y. and Ermon, S. Improved Techniques for Training Score-Based Generative Models. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
Song, Y., Garg, S., Shi, J., and Ermon, S. Sliced Score Matching: A Scalable Approach to Density and Score Estimation. In Uncertainty in Artificial Intelligence, 2020.
Song, Y., Durkan, C., Murray, I., and Ermon, S. Maximum Likelihood Training of Score-Based Diffusion Models. In Advances in Neural Information Processing Systems (NeurIPS), 2021a.
Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., and Poole, B. Score-Based Generative Modeling through Stochastic Differential Equations. In International Conference on Learning Representation (ICLR), 2021b.
Tanaka, A. Discriminator Optimal Transport. In Neural Information Processing Systems, 2019.
Tong, A., Malkin, N., Huguet, G., Zhang, Y., Rector-Brooks, J., Fatras, K., Wolf, G., and Bengio, Y. Improving and Generalizing Flow-based Generative Models with Minibatch Optimal Transport. arXiv:2302.00482v3, 2023.
Trivedi, S. and Kondor, R. Optimization for Deep Neural Networks. Slides - University of Chicago, 2017.
Vahdat, A., Kreis, K., and Kautz, J. Score-based Generative Modeling in Latent Space. Advances in Neural Information Processing Systems, 34:11287–11302, 2021.
Vanden-Eijnden, E. Introduction to Regular Perturbation Theory. Slides, 2001. URL https://cims.nyu.edu/~eve2/reg_pert.pdf.
Vargas, F., Thodoroff, P., Lamacraft, A., and Lawrence, N. Solving Schrödinger Bridges via Maximum Likelihood. Entropy, 23(9):1134, 2021.
Vempala, S. S. and Wibisono, A. Rapid Convergence of the Unadjusted Langevin Algorithm: Isoperimetry Suffices, 2022.
Villani, C. Topics in Optimal Transportation, volume 58. American Mathematical Soc., 2003.
Vono, M., Paulin, D., and Doucet, A. Efficient MCMC Sampling with Dimension-Free Convergence Rate using ADMM-type Splitting. Journal of Machine Learning Research, 2022.
Xue, S., Yi, M., Luo, W., Zhang, S., Sun, J., Li, Z., and Ma, Z.-M. SA-Solver: Stochastic Adams Solver for Fast Sampling of Diffusion Models. Advances in Neural Information Processing Systems, 2023.
Zhang, B., Luo, W., and Zhang, Z. Enhancing Adversarial Robustness via Score-Based Optimization. Advances in Neural Information Processing Systems, 36, 2024.
In Section A, we study the closed-form expression of the matrix exponential for diagonal and time-invariant $D_t$; in Section B, we study the convergence of the adaptive diffusion process; in Section C, we study the variational gap of the optimal transport and discuss its connections to the Gaussian Schrödinger bridge; in Section D, we present more details on the empirical experiments.
Notations: $\mathcal{X}$ is the state space for the data $x$; $\overleftarrow{x}^{(k)}_{nh}$ is the $n$-th backward sampling step with a learning rate $h$ at the $k$-th stage. $\eta_k$ is the step size to optimize $A$. $\mathcal{A}$ is the (latent) state space of $A$; $A^{(k)}_t$ is the forward linear score estimator at stage $k$ and time $t$, and $A^\star_t$ is the equilibrium of Eq.(25) at time $t$. $\nabla\overrightarrow{L}_t(A_t; \overleftarrow{x}_t)$ is the random field in the stochastic approximation process and also the loss (10) at time $t$; $\nabla\overrightarrow{L}_t(A_t)$, with a single argument, is the mean field with the equilibrium $A^\star_t$. Given a fixed $A^{(k)}_t$ at step $k$, $\nabla\log\overrightarrow{\rho}^{(k)}_t$ (resp. $\nabla\log\overrightarrow{\rho}^{(k)}_{t|0}$) is the (resp. conditional) forward score function of Eq.(11) at time $t$ and step $k$; $A^{(k)}_t$ yields the approximated score function $s^{(k)}_t$, and $\overleftarrow{\rho}^{(k)}_t$ is the distribution of the continuous-time interpolation of the discretized backward SDE (22).
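Therefore, we have
$$M_t^2 = \begin{pmatrix} (-\tfrac12\sigma_t^2\Lambda)^2 & 0 \\ 0 & (\tfrac12\sigma_t^2\Lambda)^2 \end{pmatrix}, \qquad M_t^3 = \begin{pmatrix} (-\tfrac12\sigma_t^2\Lambda)^3 & \sigma_t^2(\tfrac12\sigma_t^2\Lambda)^2 \\ 0 & (\tfrac12\sigma_t^2\Lambda)^3 \end{pmatrix},$$
$$M_t^4 = \begin{pmatrix} (-\tfrac12\sigma_t^2\Lambda)^4 & 0 \\ 0 & (\tfrac12\sigma_t^2\Lambda)^4 \end{pmatrix}, \qquad M_t^5 = \begin{pmatrix} (-\tfrac12\sigma_t^2\Lambda)^5 & \sigma_t^2(\tfrac12\sigma_t^2\Lambda)^4 \\ 0 & (\tfrac12\sigma_t^2\Lambda)^5 \end{pmatrix}, \quad \dots$$
According to the definition of matrix exponential, we have
$$\exp(M_t) = I + \frac{1}{1!}M_t + \frac{1}{2!}M_t^2 + \frac{1}{3!}M_t^3 + \dots$$
$$= \begin{pmatrix} \exp(-\tfrac12\sigma_t^2\Lambda) & \sigma_t^2\Big[I + \frac{1}{3!}(\tfrac12\sigma_t^2\Lambda)^2 + \frac{1}{5!}(\tfrac12\sigma_t^2\Lambda)^4 + \dots\Big] \\ 0 & \exp(\tfrac12\sigma_t^2\Lambda) \end{pmatrix}$$
$$= \begin{pmatrix} \exp(-\tfrac12\sigma_t^2\Lambda) & \sigma_t^2\big(\tfrac12\sigma_t^2\Lambda\big)^{-1}\Big[(\tfrac12\sigma_t^2\Lambda) + \frac{1}{3!}(\tfrac12\sigma_t^2\Lambda)^3 + \frac{1}{5!}(\tfrac12\sigma_t^2\Lambda)^5 + \dots\Big] \\ 0 & \exp(\tfrac12\sigma_t^2\Lambda) \end{pmatrix}$$
$$= \begin{pmatrix} \exp(-\tfrac12\sigma_t^2\Lambda) & \Lambda^{-1}\Big[\exp(\tfrac12\sigma_t^2\Lambda) - \exp(-\tfrac12\sigma_t^2\Lambda)\Big] \\ 0 & \exp(\tfrac12\sigma_t^2\Lambda) \end{pmatrix}.$$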
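Therefore, $C_t H_t^{-1} = \Lambda^{-1}\big[I - \exp(-\sigma_t^2\Lambda)\big]$. As a result, the corresponding forward transition writes
$$\mu_{t|0} = \exp(-\tfrac12\sigma_t^2\Lambda)\,x_0, \qquad L_t = \Lambda^{-\frac12}\sqrt{I - \exp(-\sigma_t^2\Lambda)},$$
so that $L_t L_t^\intercal = C_t H_t^{-1}$.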
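For a diagonal $\Lambda$ the computation decouples coordinate by coordinate, so the closed form can be sanity-checked on a single 2x2 block. The sketch below is only an illustration: the function names and the test values of $\lambda$ and $\sigma_t^2$ are assumptions introduced here, and the series is simply truncated after a fixed number of terms.

```python
import math

def mat_mul(a, b):
    """Multiply two 2x2 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def expm_series(m, terms=60):
    """Truncated Taylor series I + M/1! + M^2/2! + ... for a 2x2 matrix."""
    result = [[1.0, 0.0], [0.0, 1.0]]
    power = [[1.0, 0.0], [0.0, 1.0]]
    for n in range(1, terms):
        power = mat_mul(power, m)
        power = [[v / n for v in row] for row in power]  # now holds M^n / n!
        result = [[result[i][j] + power[i][j] for j in range(2)]
                  for i in range(2)]
    return result

def closed_form(lam, sigma2):
    """Closed-form exp(M_t) for one diagonal entry lam of Lambda."""
    a = 0.5 * sigma2 * lam
    top_right = (math.exp(a) - math.exp(-a)) / lam
    return [[math.exp(-a), top_right], [0.0, math.exp(a)]]

def forward_transition(lam, sigma2, x0):
    """Per-coordinate mean and standard deviation of the forward transition."""
    mean = math.exp(-0.5 * sigma2 * lam) * x0
    variance = (1.0 - math.exp(-sigma2 * lam)) / lam  # entry of C_t H_t^{-1}
    return mean, math.sqrt(variance)

# Hypothetical test values (not taken from the paper).
lam, sigma2 = 0.7, 1.3
m_t = [[-0.5 * sigma2 * lam, sigma2], [0.0, 0.5 * sigma2 * lam]]
series = expm_series(m_t)
closed = closed_form(lam, sigma2)
```

Comparing `series` with `closed` entry by entry, and `closed[0][1] / closed[1][1]` with the squared standard deviation returned by `forward_transition`, reproduces the two identities above for this one coordinate.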
B. Stochastic Approximation
Stochastic approximation (SA), also known as the Robbins–Monro algorithm (Robbins & Monro, 1951; Benveniste et al., 1990), offers a conventional framework for the study of adaptive algorithms. The stochastic approximation algorithm works by repeating the sampling-optimization iterations in the dynamic setting in terms of simulated trajectories. We present our algorithm in Algorithm 2.
Algorithm 2 The (dynamic) stochastic approximation (SA) algorithm. The (dynamic) SA is a theoretical formulation of Algorithm 1. We assume optimizing the loss function (16) yields proper score estimations $s^{(k+1)}_t$ at each stage $k$ and time $t$ to approximate $\nabla\log\overrightarrow{\rho}^{(k)}_t(\overrightarrow{x}_t|\overrightarrow{x}_0)$ in Eq.(9b).
repeat
  Simulation: Sample the backward process from (17) given a fixed $A^{(k)}_{nh}$
  $$\overleftarrow{x}^{(k+1)}_{(n-1)h} = \Big(-\tfrac12\big(I - 2A^{(k)}_{nh}\big)\beta_{nh}\overleftarrow{x}^{(k+1)}_{nh} - \beta_{nh}\,s^{(k+1)}_{nh}\big(\overleftarrow{x}^{(k+1)}_{nh}\big)\Big)h + \sqrt{\beta_{nh}h}\,\xi_n, \quad (22)$$
To facilitate the analysis, we assume we only make a one-step sampling in Eq.(23). Note that it is not required in practice
and multiple-step extensions can be employed to exploit the cached data more efficiently. The theoretical extension is
straightforward and omitted in the proof. We also slightly abuse the notation for convenience and generalize Anh to At .
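Theoretically, the primary objective is to show the iterates (19) follow the trajectories of the dynamical system asymptotically:
$$dA_t = \nabla\overrightarrow{L}_t(A_t)\,ds, \quad (24)$$
where $\nabla\overrightarrow{L}_t(A_t)$ is the mean field defined as follows:
$$\nabla\overrightarrow{L}_t(A_t) = \int_{\mathcal{X}} \nabla\overrightarrow{L}_t\big(A_t; \overleftarrow{x}^{(\cdot)}_t\big)\,\overleftarrow{\rho}_t\big(d\overleftarrow{x}^{(\cdot)}_t\big). \quad (25)$$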
We denote by $A^\star_t$ the solution of $\nabla\overrightarrow{L}_t(A^\star_t) = 0$. Since the samples simulated from $\overleftarrow{\rho}_t$ are slightly biased due to the convergence of the forward process, discretization error, and score estimation errors as shown in Theorem 1, we expect the mean field to be biased as well, with a perturbed equilibrium. However, by perturbation theory (Vanden-Eijnden, 2001), the perturbation is mild and controlled by the errors in Theorem 1. Hence, although $A^\star_t$ is not the optimal linear solution in terms of optimal transport, it still yields efficient transportation plans.
Since the exclusive (reverse) KL divergence is known to approximate a single mode (denoted by A⋆t ) in fitting multi-modal
distributions, we proceed to assume the following regularity conditions for the solution A⋆t and the neighborhood of A⋆t .
Assumption A1 (Regularity). (Positive definiteness) For any $t \ge 0$ and $A_t \in \mathcal{A}$, there exists a constant $\lambda_{\min} > 0$ s.t. $\lambda_{\min} I \preccurlyeq D_t = I - 2A_t$, where $A \preccurlyeq B$ means $B - A$ is positive semi-definite. (Locally strong convexity) For any stable local minimum $A^\star_t$ with $\nabla\overrightarrow{L}_t(A^\star_t) = 0$, there is always a neighborhood $\Theta$ s.t. $A^\star_t \in \Theta \subset \mathcal{A}$ and $\overrightarrow{L}_t$ is strongly convex in $\Theta$, i.e. there exist fixed constants $M > m > 0$ s.t. for all $A \in \Theta$, $mI \preccurlyeq \frac{\partial^2 \overrightarrow{L}_t}{\partial A^2}(A) \preccurlyeq MI$.
The first part of the above assumption is standard and can be achieved by an appropriate regularization during the training; the second part only assumes the strong convexity for a small neighborhood $\Theta$ of the optimum $A^\star_t$. As such, when conditions for Eq.(31) hold, we can apply the induction method to make sure all the subsequent iterates of $A^{(k)}_t$ stay in the same region $\Theta$ and converge to the local minimum $A^\star_t$. For future works, we aim to explore the connection between $m$ and $\lambda_{\min}$.
Next, we lay out three standard assumptions following Chen et al. (2023a) to conduct our analysis. Similar results are
studied by Lee et al. (2022); Chen et al. (2022a) with different score assumptions.
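Assumption A2 (Lipschitz Score). The score function $\nabla\log\overrightarrow{\rho}_t$ ($\nabla\log\overrightarrow{\rho}_{t,A}$) is $L$-Lipschitz in both $x$ and $A$ for any $t \in [0, T]$. For any $A, B \in \mathcal{A}$ and any $x, y \in \mathcal{X}$, we have
$$\|\nabla\log\overrightarrow{\rho}_{t,A}(x) - \nabla\log\overrightarrow{\rho}_{t,A}(y)\|_2 \le L\|x - y\|_2,$$
$$\|\nabla\log\overrightarrow{\rho}_{t,A}(x) - \nabla\log\overrightarrow{\rho}_{t,B}(x)\|_2 \le L\|A - B\|,$$
where $\|\cdot\|_2$ is the standard $L^2$ norm and $\|\cdot\|$ is the matrix norm.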
Assumption A3 (Second Moment Bound). The data distribution has a bounded second moment $m_2^2 := \mathbb{E}_{\rho_{\text{data}}}[\|\cdot\|_2^2] < \infty$.
Assumption A4 (Score Estimation Error). For all $t \in [0, T]$ and any $A_t$, the score estimation error is bounded:
$$\mathbb{E}_{\overrightarrow{\rho}_t}\big[\|s_t - \nabla\log\overrightarrow{\rho}_t\|_2^2\big] \le \epsilon^2_{\text{score}}.$$
We first use the multivariate diffusion to train our score estimators $\{s^{(k)}_t\}_{n=0}^{N-1}$ via the loss function (16) based on the pre-specified $A^{(k)}_t$. Following Chen et al. (2023a), we can show the generated samples based on $\{s^{(k)}_t\}_{n=0}^{N-1}$ are close in distribution to the ideal samples in Theorem 1. The novelty lies in the extension of single-variate diffusions to multi-variate diffusions.
Next, we use the stochastic approximation theory to prove the convergence of $A^{(k)}_t$ to a local equilibrium $A^\star_t$ in Theorem 2. In the end, we adapt Theorem 1 again to show the adaptively generated samples are asymptotically close to the samples based on the optimal $A^\star_t$ in Theorem 3, which further optimizes the transportation plans through a variational formulation.
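To facilitate the understanding, we summarize the details as follows:
$$\overbrace{s^{(k)}_t}^{\text{Sample via } A^{(k)}_t} \ \xRightarrow[\text{Theorem 1}]{\text{Backward Sampling}}\ \overbrace{\nabla\overrightarrow{L}_t\big(A^{(k)}_t; \overleftarrow{x}^{(k+1)}_t\big)}^{\text{Random Field}} \ \xRightarrow{\text{Eq.(25)}}\ \overbrace{\nabla\overrightarrow{L}_t\big(A^{(k)}_t\big)}^{\text{Mean Field}} \ \xRightarrow[\text{Theorem 2}]{\text{Convergence}}\ \overbrace{A^{(k)}_t \to A^\star_t}^{\text{Convergence of } A^{(k)}_t} \ \xRightarrow[\text{Theorem 3}]{\text{Adaptive Sampling}}\ \overbrace{\lim_{k\to\infty}\overleftarrow{x}^{(k+1)}_t}^{\text{Sample via } A^\star_t}.$$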
Proof Sketch
• Part B.1: The generated samples (backward trajectories) approximate the ideal samples from the fixed $A^{(k)}_t$.
• Part B.2: We employ the SA theory to show the convergence of $A^{(k)}_t$ to the optimal estimator $A^\star_t$.
• Part B.3: The adaptively generated samples approximate the ideal samples from the optimal $A^\star_t$ asymptotically.
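Theorem 1 (Generation quality). Assume assumptions A2, A3, and A4 hold. Given a fixed $A_t$ by assumption A1, the generated data distribution via the EM discretization of Eq.(17) is close to the data distribution $\rho_{\text{data}}$ such that
$$\mathrm{TV}(\overleftarrow{\rho}_0, \rho_{\text{data}}) \lesssim \underbrace{\sqrt{\mathrm{KL}(\rho_{\text{data}}\|\gamma^d)}\exp(-T)}_{\text{convergence of forward process}} + \underbrace{(L\sqrt{dh} + m_2 h)\sqrt{T}}_{\text{EM discretization}} + \underbrace{\epsilon_{\text{score}}\sqrt{T}}_{\text{score estimation}},$$
where $\gamma^d$ denotes the $d$-dimensional standard Gaussian distribution.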
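Proof Following Chen et al. (2023a), we employ the chain rule for KL divergence and obtain:
$$\mathrm{KL}(\rho_{\text{data}}\|\overleftarrow{\rho}_0) \le \mathrm{KL}(\overrightarrow{\rho}_T\|\overleftarrow{\rho}_T) + \mathbb{E}_{\overrightarrow{\rho}_T(x)}\Big[\mathrm{KL}\big(\overrightarrow{\rho}_{0|T}(\cdot|x)\,\big\|\,\overleftarrow{\rho}_{0|T}(\cdot|x)\big)\Big],$$
where $\overrightarrow{\rho}_{0|T}$ is the conditional distribution of $x_0$ given $x_T$ and likewise for $\overleftarrow{\rho}_{0|T}$. Note that the two terms correspond to the convergence of the forward and reverse process respectively. We proceed to prove that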
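where
$$J_{\overrightarrow{\rho}^\circ_t}(\overrightarrow{\rho}_t) = \int \overrightarrow{\rho}_t(x)\,\Big\|\nabla\ln\frac{\overrightarrow{\rho}_t(x)}{\overrightarrow{\rho}^\circ_t(x)}\Big\|^2 dx$$
$$\mathrm{KL}(\overrightarrow{\rho}_t\|\overrightarrow{\rho}^\circ_t) \le e^{-\int_0^t \alpha_s\beta_s ds}\,\mathrm{KL}(\overrightarrow{\rho}_0\|\overrightarrow{\rho}^\circ_0) \le e^{-\alpha\int_0^t \beta_s ds}\,\mathrm{KL}(\overrightarrow{\rho}_0\|\overrightarrow{\rho}^\circ_0),$$
where the last inequality follows from Lemma 1 and $\alpha$ is a lower bound estimate of the LSI constant $\inf_{t\in[0,T]}\alpha_t$.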
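Then by Pinsker's Inequality, we have
$$\mathrm{TV}(\overrightarrow{\rho}_t, \overrightarrow{\rho}^\circ_t) \le \sqrt{2\,\mathrm{KL}(\overrightarrow{\rho}_t\|\overrightarrow{\rho}^\circ_t)} \le \sqrt{2e^{-\alpha\int_0^t\beta_s ds}\,\mathrm{KL}(\overrightarrow{\rho}_0\|\overrightarrow{\rho}^\circ_0)} \lesssim \sqrt{\mathrm{KL}(\rho_{\text{data}}\|\gamma^d)}\exp(-t).$$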
Part II: The proof for the convergence of the reverse process is essentially identical to Theorem 2.1 of Chen et al. (2023a), with the only potential replacements being instances of $\|x_t - x_{kh}\|^2$ with $\|D_{T-t}(x_t - x_{kh})\|^2$. However, they are equivalent due to Assumption A1. Therefore, we omit the proof here.
In conclusion, the convergence follows that
$$\mathrm{KL}(\rho_{\text{data}}\|\overleftarrow{\rho}_0) \lesssim \mathrm{KL}(\rho_{\text{data}}\|\gamma^d)e^{-T} + (L^2 dh + m_2^2 h^2)T + \epsilon^2_{\text{score}} T.$$
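To connect this KL bound with the total-variation statement of Theorem 1, the following worked step may help. It is a sketch under stated assumptions: it reads the last term of the KL bound as $\epsilon^2_{\text{score}}T$ (consistent with Assumption A4), uses Pinsker's inequality in its standard form, and relies on $\sqrt{a+b} \le \sqrt{a} + \sqrt{b}$ for $a, b \ge 0$; constants are absorbed into $\lesssim$.

```latex
\begin{align*}
\mathrm{TV}(\overleftarrow{\rho}_0, \rho_{\text{data}})
  &\le \sqrt{\tfrac12\,\mathrm{KL}(\rho_{\text{data}}\|\overleftarrow{\rho}_0)} \\
  &\lesssim \sqrt{\mathrm{KL}(\rho_{\text{data}}\|\gamma^d)\,e^{-T}
      + (L^2 dh + m_2^2 h^2)\,T + \epsilon^2_{\text{score}}\,T} \\
  &\le \sqrt{\mathrm{KL}(\rho_{\text{data}}\|\gamma^d)}\;e^{-T/2}
      + \big(L\sqrt{dh} + m_2 h\big)\sqrt{T} + \epsilon_{\text{score}}\sqrt{T}.
\end{align*}
```

The discretization and score terms match Theorem 1 directly. The first term comes out with rate $e^{-T/2}$ in this sketch; whether it matches the $\exp(-T)$ written in Theorem 1 depends on how the time scale of $\beta_t$ is normalized, which is not tracked here.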
Lemma 1 (Lower bound of the log-Sobolev constant). Under the same assumptions and setups in Theorem 1, we have
• Randomness from the initial: By the mean diffusion in Eq.(12a), the conditional mean diffusion of $y_t$ at time $t$, denoted by $\mu_{t,y}$, follows that $\mu_{t,y} = \mathbf{D}_t\mu_{0,y}$, where $\mathbf{D}_t = e^{-\frac12[\beta D]_t}$. Since $y_0 \sim \mathrm{N}(0, I)$, we know $\mu_{t,y} \sim \mathrm{N}(0, \mathbf{D}_t\mathbf{D}_t^\intercal)$.
• Randomness from Brownian motion: the covariance diffusion induced by Brownian motion follows from $\Sigma_{t|0}$ in Eq.(12b).
Since $y_0 \sim \mathrm{N}(0, I)$ and $y_t$ is an OU process in Eq.(11), we know that $y_t$ is always Gaussian at time $t \ge 0$ with mean 0. As such, we know that
$$\overrightarrow{\rho}^\circ_t = \mathrm{N}(0, \mathbf{D}_t\mathbf{D}_t^\intercal + \Sigma_{t|0}). \quad (26)$$
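It follows that
$$\mathrm{TV}(\overrightarrow{\rho}_t, \overrightarrow{\rho}^\circ_t) \le \sqrt{2e^{-\int_0^t\alpha_s\beta_s ds}\,\mathrm{KL}(\overrightarrow{\rho}_0\|\overrightarrow{\rho}^\circ_0)}.$$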
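Now we need to bound the log-Sobolev constant $\alpha_t$ of $\overrightarrow{\rho}^\circ_t$. Let $\Sigma_t = \mathbf{D}_t\mathbf{D}_t^\intercal + \Sigma_{t|0}$. Recall that if a distribution $p$ is $\alpha$-strongly log-concave, then it satisfies the log-Sobolev inequality (LSI) with LSI constant $\alpha$ (Vempala & Wibisono, 2022). So for the Gaussian distribution $\overrightarrow{\rho}^\circ_t$, it suffices to bound the largest eigenvalue of $\Sigma_t$ from above, or equivalently the smallest eigenvalue of $\Sigma_t^{-1}$ from below. Recall from Eq.(12b) that $\Sigma_t$ satisfies the ODE
$$\frac{d\Sigma_t}{dt} = -\frac12\beta_t\big(D_t\Sigma_t + \Sigma_t D_t^\intercal\big) + \beta_t I, \qquad \Sigma_0 = I.$$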
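Fix a normalized vector $x \in \mathbb{R}^d$ and denote $u_t = x^\intercal\Sigma_t x$ for $t \in [0, T]$. By the cyclical property of the trace, we have
$$\frac{du_t}{dt} = x^\intercal\frac{d\Sigma_t}{dt}x = -\beta_t\,\mathrm{Tr}\big(D_t\Sigma_t x x^\intercal\big) + \beta_t.$$
It follows that
$$\frac{du_t}{dt} \le -\lambda_{\min}\beta_t u_t + \beta_t.$$
Solving this differential inequality with $u_0 = 1$ gives
$$u_t \le \frac{1}{\lambda_{\min}}\Big(1 - e^{-\lambda_{\min}\int_0^t\beta_s ds}\Big) + e^{-\lambda_{\min}\int_0^t\beta_s ds} \le \max\{1, 1/\lambda_{\min}\}.$$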
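The step from the differential inequality to the bound on $u_t$ can be spelled out with an integrating factor. The following is a sketch in the notation above, writing $B_t := \int_0^t \beta_s ds$ (a shorthand introduced here for illustration) and using $u_0 = x^\intercal\Sigma_0 x = 1$, which holds because $\Sigma_0 = I$ and $x$ is normalized.

```latex
\begin{align*}
\frac{d}{dt}\Big(e^{\lambda_{\min} B_t}\,u_t\Big)
  &= e^{\lambda_{\min} B_t}\Big(\frac{du_t}{dt} + \lambda_{\min}\beta_t u_t\Big)
   \le \beta_t\,e^{\lambda_{\min} B_t}, \\
e^{\lambda_{\min} B_t}\,u_t - 1
  &\le \int_0^t \beta_s\,e^{\lambda_{\min} B_s}\,ds
   = \frac{1}{\lambda_{\min}}\Big(e^{\lambda_{\min} B_t} - 1\Big), \\
u_t
  &\le \frac{1}{\lambda_{\min}}\Big(1 - e^{-\lambda_{\min} B_t}\Big)
     + e^{-\lambda_{\min} B_t}.
\end{align*}
```

The right-hand side is a convex combination of $1/\lambda_{\min}$ and $1$ with weights $1 - e^{-\lambda_{\min}B_t}$ and $e^{-\lambda_{\min}B_t}$, which is why it never exceeds $\max\{1, 1/\lambda_{\min}\}$.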
Since $x$ can be any normalized vector, the largest eigenvalue of $\Sigma_t$ is bounded by $\max\{1, 1/\lambda_{\min}\}$ and hence $\alpha_t \ge 1/\max\{1, 1/\lambda_{\min}\} = \min\{1, \lambda_{\min}\}$.
Remark: In our theoretical analysis, we introduced an auxiliary variable $y_0 \sim \gamma^d$ to make sure $\mathrm{KL}(\rho_{\text{data}}\|\gamma^d)$ is well defined. Moreover, the distribution of $y_T$ is set to $\overrightarrow{\rho}^\circ_T$ in Eq.(26). However, we emphasize that the introduction of $y_t$ is only for theoretical analysis and we adopt a simpler prior $\mathrm{N}(0, \Sigma_{T|0})$ instead of $\mathrm{N}(0, \mathbf{D}_T\mathbf{D}_T^\intercal + \Sigma_{T|0})$ in Eq.(26) for convenience.
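Proof By the smoothness assumption A2 and Taylor expansion, for any $A \in \Theta$, we have
$$\nabla\overrightarrow{L}_t(A) = \nabla\overrightarrow{L}_t(A^\star_t) + \mathrm{Hess}\,\overrightarrow{L}_t(\tilde{A})\,(A - A^\star_t) = \mathrm{Hess}\,\overrightarrow{L}_t(\tilde{A})\,(A - A^\star_t), \quad (27)$$
where $\mathrm{Hess}\,\overrightarrow{L}_t(A)$ denotes the Hessian of $\overrightarrow{L}_t$ with $A$ at time $t$; $\tilde{A}$ is some value between $A$ and $A^\star_t$ by the mean-value theorem. Next, we can get
$$\langle A - A^\star_t, \nabla\overrightarrow{L}_t(A)\rangle = \big\langle A - A^\star_t,\ \mathrm{Hess}\,\overrightarrow{L}_t(\tilde{A})\,(A - A^\star_t)\big\rangle \ge m\|A - A^\star_t\|_2^2,$$
where the inequality follows from the local strong convexity in Assumption A1.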
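It follows that
$$\mathbb{E}\big[\|\nabla\overrightarrow{L}_t(A^{(k)}_t; \overleftarrow{x}^{(k+1)}_t)\|_2^2 \,\big|\, \mathcal{F}_k\big] = \mathbb{E}\big[\|\nabla\overrightarrow{L}_t(A^{(k)}_t; \overleftarrow{x}^{(k+1)}_t) - \nabla\overrightarrow{L}_t(A^{(k)}_t) + \nabla\overrightarrow{L}_t(A^{(k)}_t)\|_2^2 \,\big|\, \mathcal{F}_k\big]$$
$$= \mathbb{E}\big[\|\nabla\overrightarrow{L}_t(A^{(k)}_t; \overleftarrow{x}^{(k+1)}_t) - \nabla\overrightarrow{L}_t(A^{(k)}_t)\|_2^2 \,\big|\, \mathcal{F}_k\big] + \|\nabla\overrightarrow{L}_t(A^{(k)}_t)\|_2^2 \quad (29)$$
$$\le \sup\,\mathbb{E}\big[\|\nabla\overrightarrow{L}_t(A^{(k)}_t; \overleftarrow{x}^{(k+1)}_t) - \nabla\overrightarrow{L}_t(A^{(k)}_t)\|_2^2 \,\big|\, \mathcal{F}_k\big] + M^2\|A^{(k)}_t - A^\star_t\|_2^2,$$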
Next, we make standard assumptions on the step size following Benveniste et al. (1990) (page 245).
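Assumption A5 (Step size). The step size $\{\eta_k\}_{k\in\mathbb{N}}$ is a positive and decreasing sequence
$$\eta_k \to 0, \qquad \sum_{k=1}^{\infty}\eta_k = +\infty, \qquad \liminf_{k\to\infty}\Big(2m\frac{\eta_k}{\eta_{k+1}} + \frac{\eta_{k+1} - \eta_k}{\eta_{k+1}^2}\Big) := \kappa > 0.$$
A standard choice is to set $\eta_k := \frac{A}{k^\alpha + B}$ for some $\alpha \in (\frac12, 1]$ and some suitable constants $A > 0$ and $B > 0$.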
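The role of this step-size schedule can be illustrated on a toy problem. The sketch below is not the algorithm analyzed in this section: it replaces the variational loss with a hypothetical scalar quadratic $\frac12(\theta - \theta^\star)^2$ whose gradient is observed with unit Gaussian noise, and all names and constants ($a$, $b$, $\alpha$, the target value) are assumptions chosen for illustration.

```python
import random

def step_size(k, a=1.0, b=10.0, alpha=0.8):
    """Step size eta_k = a / (k**alpha + b) with alpha in (1/2, 1]."""
    return a / (k ** alpha + b)

def noisy_gradient(theta, target, rng):
    """Random field: gradient of 0.5 * (theta - target)**2 plus zero-mean noise."""
    return (theta - target) + rng.gauss(0.0, 1.0)

def robbins_monro(theta0, target, n_iters=20000, seed=0):
    """Run theta <- theta - eta_k * (noisy gradient) and return the last iterate."""
    rng = random.Random(seed)
    theta = theta0
    for k in range(1, n_iters + 1):
        theta -= step_size(k) * noisy_gradient(theta, target, rng)
    return theta

# Hypothetical toy run: the mean field (theta - target) vanishes at target.
estimate = robbins_monro(theta0=5.0, target=-0.3)
```

In this toy setting the mean field is $\theta - \theta^\star$, so the equilibrium is the target itself, and the decreasing step size averages out the gradient noise while the divergent sum $\sum_k \eta_k$ lets the iterate travel from any starting point.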
Theorem 2 (Convergence in $L^2$). Assume assumptions A1, A2, A3, A4, and A5 hold. The variational score $A^{(k)}_t$ in Algorithm 2 converges to a local minimizer $A^\star_t$. In other words, given a large enough $k \ge k_0$, where $\eta_{k_0} \le \frac12$, we have
$$\mathbb{E}\big[\|A^{(k)}_t - A^\star_t\|_2^2\big] \le 2\eta_k,$$
• Given some large enough $k \ge k_0$, where $\eta_{k_0} \le \frac12$, $A^{(k)}_t$ is in some subset $\Theta$ ** of $\mathcal{A}$ that follows
where the first inequality holds by the stability property in Lemma 2 and the last inequality follows from the growth property in Lemma 3.
Since $A^\star_t, A^{(k)}_t \in \Theta$, Eq.(32) implies that $A^{(k+1)}_t \in \Theta$, which concludes the proof.
** By assumption A1, such $\Theta \subset \mathcal{A}$ exists; otherwise it implies that the mean field function is a constant and the conclusion holds as well.
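Proof By assumption A4, for any $A^{(k)}_t \in \mathcal{A}$, we have
$$\mathbb{E}_{\overrightarrow{\rho}^{(k)}_t}\big[\|s^{(k)}_t - \nabla\log\overrightarrow{\rho}^{(k)}_t\|_2^2\big] \le \epsilon^2_{\text{score}}.$$
Combining Theorem 2 and the smoothness assumption A2 of the score function $\nabla\log\overrightarrow{\rho}^{(k)}_t$ w.r.t. $A^{(k)}_t$, we have
$$\mathbb{E}_{\overrightarrow{\rho}^{(k)}_t}\big[\|\nabla\log\overrightarrow{\rho}^{(k)}_t - \nabla\log\overrightarrow{\rho}^\star_t\|_2^2\big] \lesssim \eta_k. \quad (33)$$
It follows that
$$\mathbb{E}_{\overrightarrow{\rho}^{(k)}_t}\big[\|s^{(k)}_t - \nabla\log\overrightarrow{\rho}^\star_t\|_2^2\big] \lesssim \underbrace{\mathbb{E}_{\overrightarrow{\rho}^{(k)}_t}\big[\|s^{(k)}_t - \nabla\log\overrightarrow{\rho}^{(k)}_t\|_2^2\big]}_{\text{by Assumption A4}} + \underbrace{\mathbb{E}_{\overrightarrow{\rho}^{(k)}_t}\big[\|\nabla\log\overrightarrow{\rho}^{(k)}_t - \nabla\log\overrightarrow{\rho}^\star_t\|_2^2\big]}_{\text{by Eq.(33)}} \quad (34)$$
$$\lesssim \epsilon^2_{\text{score}} + \eta_k.$$
Applying Theorem 1 with the adaptive score error in Eq.(34) to replace $\epsilon^2_{\text{score}}$ concludes the proof.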
Remark: The convergence of samples based on the adaptive algorithms is slightly weaker than the standard one due to the
adaptive update, but this is necessary because A⋆t is more transport efficient than a vanilla At .
C. Variational Gap
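Recall that the optimal forward SDE in the forward-backward SDEs (5) follows that
$$d\overrightarrow{x}_t = \big[f_t(\overrightarrow{x}_t) + \beta_t\nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t. \quad (35)$$
The forward SDE with the optimal variational score $A^\star_t$ follows that
$$d\overrightarrow{x}_t = \big[f_t(\overrightarrow{x}_t) + \beta_t A^\star_t\overrightarrow{x}_t\big]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t. \quad (36)$$
The forward SDE with the variational score $A^{(k)}_t$ at stage $k$ follows that
$$d\overrightarrow{x}_t = \big[f_t(\overrightarrow{x}_t) + \beta_t A^{(k)}_t\overrightarrow{x}_t\big]dt + \sqrt{\beta_t}\,d\overrightarrow{w}_t. \quad (37)$$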
Since we only employ a linear approximation of the forward score function, our transport is only sub-optimal. To assess the
extent of this discrepancy, we leverage the Girsanov theorem to study the variational gap.
We denote the law of the processes by $\mathrm{L}(\cdot)$ in Eq.(35), $\mathrm{L}^\star(\cdot)$ in Eq.(36) and $\mathrm{L}^{(k)}(\cdot)$ in Eq.(37), respectively.
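Theorem 4. Assume assumptions A2 and A3 hold. Assume $f_t$ and $\nabla\log\overrightarrow{\psi}_t$ are Lipschitz smooth and satisfy the linear growth condition. Assume Novikov's condition holds for all $A_t \in \mathcal{A}$, where $t \in [0, T]$:
$$\mathbb{E}\Big[\exp\Big(\frac12\int_0^T\big\|\beta_t A_t\overrightarrow{x}_t - \beta_t\nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\,dt\Big)\Big] < \infty.$$
The variational gap (VG) via the linear parametrization is upper bounded by
$$\mathrm{KL}(\mathrm{L}\|\mathrm{L}^\star) = \frac12\int_0^T\mathbb{E}_{\overrightarrow{\rho}_t}\Big[\beta_t\big\|A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\Big]dt,$$
$$\mathrm{KL}(\mathrm{L}\|\mathrm{L}^{(k)}) \lesssim \eta_k + \mathrm{KL}(\mathrm{L}\|\mathrm{L}^\star).$$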
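Proof
By Girsanov's formula (Liptser & Shiryaev, 2001), the Radon–Nikodym derivative of $\mathrm{L}(\cdot)$ w.r.t. $\mathrm{L}^\star(\cdot)$ follows that
$$\frac{d\mathrm{L}}{d\mathrm{L}^\star}(\overrightarrow{x}) = \exp\Big(\int_0^T\sqrt{\beta_t}\big(A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big)dw_t - \frac12\int_0^T\beta_t\big\|A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\,dt\Big),$$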
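where $w_t$ is the Brownian motion under the Wiener measure. Consider a change of measure (Øksendal, 2003; Chewi, 2023)
$$w_t = \tilde{w}_t - d\langle w, M\rangle_t, \qquad dM_t = \Big\langle\sqrt{\beta_t}\big(A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big),\, dw_t\Big\rangle,$$
where $\tilde{w}_t$ is an $\mathrm{L}$-standard Brownian motion and satisfies the martingale property under the $\mathrm{L}$ measure.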
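Now the variational gap is upper bounded by
$$\mathrm{KL}(\mathrm{L}(\cdot)\|\mathrm{L}^\star(\cdot)) = -\mathbb{E}_{\mathrm{L}(\cdot)}\Big[\log\frac{d\mathrm{L}^\star(\cdot)}{d\mathrm{L}(\cdot)}\Big]$$
$$= \mathbb{E}_{\mathrm{L}(\cdot)}\Big[\int_0^T\sqrt{\beta_t}\big(A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big)d\tilde{w}_t + \frac12\int_0^T\beta_t\big\|A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\,dt\Big]$$
$$= \frac12\,\mathbb{E}_{\mathrm{L}(\cdot)}\Big[\int_0^T\beta_t\big\|A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\,dt\Big]$$
$$= \frac12\int_0^T\mathbb{E}_{\overrightarrow{\rho}_t}\Big[\beta_t\big\|A^\star_t\overrightarrow{x}_t - \nabla\log\overrightarrow{\psi}_t(\overrightarrow{x}_t)\big\|_2^2\Big]dt.$$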
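The final expression for the variational gap is a time integral of an expected squared residual, so it can be estimated numerically once samples of $\overrightarrow{x}_t$ are available. The sketch below handles a scalar state only and is purely illustrative: `a_star`, `grad_log_psi`, `beta` and `sample_x` are hypothetical user-supplied callables (the last one standing in for a sampler of $\overrightarrow{\rho}_t$), and the midpoint rule and sample sizes are arbitrary choices.

```python
import random

def variational_gap(a_star, grad_log_psi, beta, sample_x,
                    T=1.0, n_steps=50, n_samples=2000, seed=0):
    """Estimate 0.5 * int_0^T beta_t * E|A*_t x_t - grad log psi_t(x_t)|^2 dt."""
    rng = random.Random(seed)
    dt = T / n_steps
    total = 0.0
    for n in range(n_steps):
        t = (n + 0.5) * dt  # midpoint rule in time
        acc = 0.0
        for _ in range(n_samples):
            x = sample_x(t, rng)  # hypothetical draw from the marginal at time t
            residual = a_star(t) * x - grad_log_psi(t, x)
            acc += residual * residual
        total += 0.5 * beta(t) * (acc / n_samples) * dt
    return total

# Hypothetical toy inputs: a linear score is matched exactly by a linear A*,
# while a constant offset of 1 leaves a gap of 0.5 * beta * 1**2 * T.
gaussian = lambda t, rng: rng.gauss(0.0, 1.0)
gap_linear = variational_gap(lambda t: -0.5, lambda t, x: -0.5 * x,
                             lambda t: 2.0, gaussian)
gap_offset = variational_gap(lambda t: -0.5, lambda t, x: -0.5 * x + 1.0,
                             lambda t: 2.0, gaussian)
```

With these toy inputs the first estimate is zero and the second is close to 1, which illustrates that the gap measures only the part of the forward score that a linear map cannot capture.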
D. Experimental Details
D.1. Parametrization of the Variational Score
For the general transport, there is no closed-form update, and we adopt a singular value decomposition (SVD) with time embeddings to learn the linear dynamics in Figure 6. The number of parameters is reduced by thousands of times, which greatly reduces the training variance (Grathwohl et al., 2019).
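A factorized linear map of this kind can be sketched in a few lines. The code below is not the authors' implementation: it assumes, purely for illustration, that each of $U$ and $V$ is a single Householder reflection (one simple way to keep a matrix orthogonal) and that the singular values follow a hypothetical sinusoidal time embedding; every function and parameter name is introduced here.

```python
import math

def householder_apply(v, x):
    """Apply the orthogonal reflection H = I - 2 v v^T / (v^T v) to x in O(d)."""
    vv = sum(c * c for c in v)
    coef = 2.0 * sum(a * b for a, b in zip(v, x)) / vv
    return [xi - coef * vi for xi, vi in zip(x, v)]

def singular_values(t, base, slope):
    """Hypothetical time embedding: lambda_i(t) = base_i + slope_i * sin(t)."""
    return [b + s * math.sin(t) for b, s in zip(base, slope)]

def linear_score(t, x, u_vec, v_vec, base, slope):
    """Compute A_t x with A_t = U diag(lambda(t)) V^T without forming A_t."""
    z = householder_apply(v_vec, x)       # V^T x (a reflection is symmetric)
    lam = singular_values(t, base, slope)
    z = [l * zi for l, zi in zip(lam, z)]  # scale by the singular values
    return householder_apply(u_vec, z)     # U z

# Hypothetical 3-dimensional example.
x = [1.0, -2.0, 0.5]
y = linear_score(0.4, x, [1.0, 2.0, 3.0], [0.3, 1.0, -0.7],
                 base=[2.0, 2.0, 2.0], slope=[0.0, 0.0, 0.0])
```

Under these assumptions the map stores a number of parameters linear in the dimension $d$ rather than the $d^2$ entries of a dense matrix, and the orthogonality of $U$ and $V$ holds by construction rather than through a constraint during training.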
Figure 6. Architecture of the linear module. Both U and V are orthogonal matrices and Λ denotes the singular values.
Figure 7. Variational Schrödinger diffusion models (VSDMs, right) vs. SGMs (left) with the same hyperparameters (βmax = 10).
Computational Time We tried different budgets to train the variational scores and observed in Figure 9(b) that 300
iterations yield the fastest convergence among the 4 choices but also lead to 23% extra time compared to SGM. Reducing the
number of iterations impacts convergence minimally due to the linearity of the variational scores and significantly reduces
the training time.
metric by approximating the second derivative of the probability flow (18) as follows
$$S(i) = \int_0^T\mathbb{E}_{\overleftarrow{x}_t\sim\overleftarrow{\rho}_t}\Big[\Big|\frac{d^2\overleftarrow{x}_t(i)}{dt^2}\Big|\Big]dt, \quad (38)$$
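A straightness score of this type can be approximated from discretized trajectories with central finite differences. The sketch below makes several assumptions that are not taken from the text: trajectories are stored on a uniform time grid, one coordinate is handled at a time, the absolute value of the second derivative is integrated so that curvature of opposite signs does not cancel, and the expectation is replaced by an average over the supplied paths.

```python
def second_difference(path, dt):
    """Central finite-difference estimate of d^2 x / dt^2 at interior grid points."""
    return [(path[n + 1] - 2.0 * path[n] + path[n - 1]) / (dt * dt)
            for n in range(1, len(path) - 1)]

def straightness(paths, dt):
    """Average over paths of the time integral of |d^2 x_t / dt^2| (one coordinate)."""
    total = 0.0
    for path in paths:
        total += sum(abs(a) for a in second_difference(path, dt)) * dt
    return total / len(paths)

# Hypothetical trajectories on the grid t = 0, 0.1, ..., 1.0.
dt = 0.1
line = [2.0 * n * dt for n in range(11)]       # straight path, zero curvature
parabola = [(n * dt) ** 2 for n in range(11)]  # constant second derivative 2
```

A perfectly straight path scores zero, while the parabola scores roughly 2 times the length of the interior time window, so lower values indicate straighter trajectories.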
Table 4. Straightness metric defined in Eq.(38) via SGMs and VSDM with different βmax's. SGM with βmax = 10 (SGM-10) fails to generate data of anisotropic shapes and is not reported.

Straightness (X / Y)    Spiral-8Y      Checkerboard-6X
SGM-20                  8.3 / 49.3     53.5 / 11.0
SGM-30                  9.4 / 57.3     64.6 / 13.1
VSDM-20                 6.3 / 45.6     49.4 / 7.4
VSDM-10                 5.5 / 38.7     43.9 / 6.5
Figure 10. Variational Schrödinger diffusion models (bottom) vs. SGMs (top) with the same hyperparameters (βmax = 20) and six function evaluations (NFE=6). Samples from both models are generated by the probability flow ODE.
Figure 11. Variational Schrödinger diffusion models (bottom) vs. SGMs (top) with the same hyperparameters (βmax = 20) and eight function evaluations (NFE=8). Samples from both models are generated by the probability flow ODE.
Training. We adopt the encoder-decoder architecture as described in the main text, and change the decoder to either our generative model or one of the competitors. The encoder is an LSTM with 2 layers and a hidden dimension of 64. We train the model for 200 epochs, where each epoch takes 50 model updates. In the case of our model, we also alternate between the two training directions at a predefined rate. The neural network parameterizing the backward direction has the same hyperparameters as in Rasul et al. (2021), that is, it has 8 layers, 8 channels, and a hidden dimension of 64. The DDPM baseline uses a standard setting for the linear beta-scheduler: βmin = 0.0001, βmax = 0.1 and 150 steps.