Student Perspectives: Strategies for variational inference in non-conjugate problems

A post by Qi Chen, PhD student on the Compass programme.


Introduction

Variational inference is a method for approximating posterior distributions. In the context of Bayesian statistics, we would like access to the posterior distribution \[p(\theta|x) = \frac{p(x|\theta)p(\theta)}{\int_{\Theta} p(x|\theta)p(\theta) d\theta}\]
In most cases the denominator $p(x)$ is intractable; that is, we cannot compute it analytically. How should we proceed? There are two broad ways:

  • Use MCMC to simulate samples from the posterior distribution $p(\theta|x)$ to approximate the true posterior and obtain statistics of interest (mean, variance, etc.).
  • Approximate $p({\theta}|{x})\approx q(\theta)\in\mathcal{Q}$.

The former method is unbiased and its convergence is guaranteed by the law of large numbers. But it requires a large number of samples and is quite computationally demanding when the dimension of the parameters/dataset is large. The latter, called variational inference, is biased (with the bias depending on the choice of $\mathcal{Q}$) but is much faster and more scalable.

We call $q$ the variational distribution. The idea behind variational inference is to approximate the posterior $p({\theta}|{x})$ using $q({\theta})\in\mathcal{Q}$ by minimizing the KL divergence between $q({\theta})$ and the true posterior $p({\theta}|{x})$, with the following formal expression:\[q^*({\theta}) = argmin_{q\in\mathcal{Q}}\;KL(q({\theta})||p({\theta}|{x})),\quad KL(q||p) = \int_{\Theta} q({\theta})\log\left(\frac{q({\theta})}{p({\theta}|{x})}\right)d{\theta}\]
This is a traditional measure of mismatch between distributions over the same domain, and it is easy to see that $KL(q||p)=0$ if and only if $q = p$ (almost everywhere).
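
As a quick sanity check, here is a minimal sketch (my own illustration, using the closed-form KL divergence between univariate Gaussians) showing that the divergence vanishes exactly when $q = p$:

```python
import numpy as np

def kl_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) for q = N(mu_q, sigma_q^2) and p = N(mu_p, sigma_p^2)."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
            - 0.5)

print(kl_gaussians(0.0, 1.0, 0.0, 1.0))  # 0.0: q equals p
print(kl_gaussians(1.0, 2.0, 0.0, 1.0))  # positive: q differs from p
```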

There are broadly two questions we would like to answer:

  • How do we minimize the KL divergence over $\mathcal{Q}$ when the true posterior is unknown?
  • How do we choose the variational family $\mathcal{Q}$?

We now answer the first question:
Notice that \begin{align*}
\log p({x}) &= \int q({\theta})\log\left(\frac{p({x},\theta)q({\theta})}{p({\theta}|{x})q({\theta})}\right) d{\theta}\\
&= \int q({\theta}) \log\left(\frac{p({x},{\theta})}{q({\theta})}\right) d{\theta} + \int q({\theta})\log\left(\frac{q({\theta})}{p({\theta}|{x})}\right)d{\theta}
\end{align*}
From the above derivation, we see that the second term is exactly the KL divergence we wish to minimize. As $\log p({x})$ is fixed, minimizing the KL divergence is equivalent to maximizing the first term. This answers the first question. The first term is called the evidence lower bound (ELBO), written as $\mathcal{L}(q({\theta}))$.

For the second question, in theory, supposing the variational distribution is parametrized by variational parameters ${\phi}$, we can start with any variational distribution we like, subject to the basic criterion:
Supp($q({\theta};{\phi})$) $\subseteq$ Supp($p({\theta}|{x})$).
We also need Supp($p({\theta}|{x})$) $\subseteq$ Supp($p({\theta})$), which is guaranteed in most cases.

But randomly choosing a variational distribution for an arbitrary model does not always yield a feasible algorithm. Indeed, all VI methods center around the goal of optimizing the ELBO \[\phi^* = argmax_{{\phi}}\mathbb{E}_q\left[\log \frac{p({x},{\theta})}{q({\theta})}\right]\]

Traditional methods make the mean-field assumption that all parameters are independent. This factorizes the objective, and a local optimum can be reached via a coordinate ascent algorithm. Some methods enlarge the mean-field space to allow specific conditional dependences between parameters, according to graphical models with conjugate exponential relationships between parent-child pairs [1]. This has been further extended to non-conjugate pairs with custom approximations.

Some modern methods developed in the last decade are based on the idea that the gradient of the ELBO can be expressed in the form $\mathbb{E}_q(\cdot)$. This immediately suggests a combination of MC algorithms (for sampling from $q$) and stochastic gradient descent (for efficiency in the optimization). These methods benefit from the simplicity that there is no need to analytically compute the gradients based on conditional dependence specifications for each model: the algorithm is automatic, and it applies to a larger class of models. But it is worth noting that even though these methods are theoretically sound, they still face practical issues, which I will show in the later sections.

In this post I will briefly go through some of these methods, specifically coordinate ascent variational inference, black-box variational inference and automatic differentiation variational inference.

Conjugate models: Coordinate Ascent Variational Inference

There are various assumptions we can make on $\mathcal{Q}$. We start with the mean-field assumption on the parameters [2]. This assumes that the joint variational distribution of all parameters factorizes completely. That is:\[q({\theta}) = \prod_{j=1}^m q_j({\theta}_j)\]
We write ${\theta}_{-j}$ to denote all the latent variables except ${\theta}_j$, with distribution $q_{-j}$.
If we maximize $\mathcal{L}(q)$ with respect to $q_j({\theta}_j)$ alone, we are maximizing \[\mathbb{E}_{q_j}[\mathbb{E}_{q_{-j}}[\log p({\theta},{x})]] - \mathbb{E}_{q_j}[\log q_j({\theta}_j)]\]

Further, write $r_j({\theta}_j) = \frac{1}{Z_j}\exp\{\mathbb{E}_{q_{-j}}[\log p({x},{\theta})]\}$, where $Z_j$ is a normalizing constant so that $r_j$ is a probability distribution. Substituting in, we get \[\mathcal{L}(q_j) = \mathbb{E}_{q_j}\left[\log \frac{r_j({\theta}_j)}{q_j({\theta}_j)}\right] + \text{const} = -KL(q_j({\theta}_j)||r_j({\theta}_j)) + \text{const}\]

Thus maximizing the ELBO with respect to $q_j$ is equivalent to setting $q_j = r_j$, which is \[q_j({\theta}_j)\propto \exp\{\mathbb{E}_{q_{-j}}[\log p({x},{\theta})]\}\propto \exp\{\mathbb{E}_{q_{-j}}[\log p({\theta}_j|{\theta}_{-j},{x})]\}\]

Since we assume $q$ factorizes, maximizing $\mathcal{L}(q)$ splits into $m$ steps of maximizing $\mathcal{L}(q_j)$. This algorithm is called coordinate ascent variational inference (CAVI) or block-coordinate ascent.

An algorithmic view is:

  1.  Initialize $q({\theta}) = \prod_{j=1}^m q_j({\theta}_j)$
  2. Iterate until convergence:
    Update each $q_j$ by $q_j = \frac{1}{Z_j}\exp(\mathbb{E}_{q_{-j}}[\log p({\theta},{x})])$.
This algorithm is guaranteed to converge since the ELBO increases at each iteration.

This is not directly feasible in all cases, since we assume we can compute $r_j$ analytically. It becomes feasible when there is conditional conjugacy between the likelihood and the prior on each $\theta_j$ conditioned on all the other ${\theta}_{i\neq j}$. That is, \[p(\theta_j|{\theta}_{i\neq j})\in \mathcal{A}(\alpha),\;p({x}|\theta_j, \theta_{i\neq j})\in \mathcal{B}(\theta_j)\;\rightarrow\; p(\theta_j|{x},{\theta}_{i\neq j})\in \mathcal{A}(\alpha')\]
One particular such family arises when all complete conditionals lie in an exponential family.

A distribution $p({\theta})$ is in an exponential family if \[p(\theta) = h({\theta})\exp\{{\eta}^Tt({\theta}) - A({\eta})\}\]
Here $\eta$ is called the natural parameter, and $A({\eta})$ satisfies \[A({\eta}) = \log \int h({\theta}) \exp\{\eta^Tt(\theta)\}d{\theta}\]
so that the density integrates to 1.
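
For example, the univariate Gaussian $\mathcal{N}(\mu,\sigma^2)$ fits this form with \[h(\theta) = \frac{1}{\sqrt{2\pi}},\quad t(\theta) = (\theta,\theta^2)^T,\quad {\eta} = \left(\frac{\mu}{\sigma^2},\,-\frac{1}{2\sigma^2}\right)^T,\quad A({\eta}) = -\frac{\eta_1^2}{4\eta_2} - \frac{1}{2}\log(-2\eta_2)\]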

Now assume that all the complete conditionals belong to an exponential family distribution, that is \[p({\theta}_j|{\theta}_{-j},{x}) = h({\theta}_j)\exp \{{\eta}_j^T({\theta}_{-j},{x}){\theta}_j – A({\eta}_j({\theta}_{-j},{x}))\}\]
where we assume that ${\theta}_j$ has already been transformed to its appropriate sufficient statistic. We now see that the CAVI update becomes \begin{align*}
q_j({\theta}_j)&\propto \exp\{\log h({\theta}_j) + \mathbb{E}_{q_{-j}}[{\eta}_j({\theta}_{-j},{x})]^T{\theta}_j – \mathbb{E}_{q_{-j}}[A({\eta}_j({\theta}_{-j},{x}))]\}\\
&\propto h({\theta}_j)\exp\{\mathbb{E}_{q_{-j}}[{\eta}_j({\theta}_{-j},{x})]^T{\theta}_j\}
\end{align*}
where we see that the variational factors are in the same exponential family (due to conjugacy) as the complete conditionals, with the natural parameter updated to \[\phi_j = \mathbb{E}_{q_{-j}}[{\eta}_j({\theta}_{-j},{x})]\]
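
To make this concrete, here is a minimal CAVI sketch (my own toy example, not taken from the references) for a conjugate model: data $x_i \sim \mathcal{N}(\mu, \tau^{-1})$ with priors $\mu|\tau \sim \mathcal{N}(\mu_0, (\kappa_0\tau)^{-1})$ and $\tau \sim \mathrm{Gamma}(a_0, b_0)$, and mean-field family $q(\mu)q(\tau)$. The closed-form coordinate updates follow the general recipe $q_j \propto \exp\{\mathbb{E}_{q_{-j}}[\log p({x},{\theta})]\}$:

```python
import numpy as np

# Toy data and (hypothetical) prior hyperparameters
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=500)
n, xbar = len(x), x.mean()
mu0, kappa0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# Variational factors: q(mu) = N(mu_n, 1/lam_n), q(tau) = Gamma(a_n, b_n)
a_n = a0 + (n + 1) / 2          # this update does not depend on q(mu)
b_n, mu_n, lam_n = b0, xbar, 1.0

for _ in range(50):             # coordinate ascent until (practical) convergence
    e_tau = a_n / b_n           # E_q[tau]
    # Update q(mu): precision-weighted combination of prior and data
    mu_n = (kappa0 * mu0 + n * xbar) / (kappa0 + n)
    lam_n = (kappa0 + n) * e_tau
    # Update q(tau), using E_q[(mu - c)^2] = (mu_n - c)^2 + 1/lam_n
    b_n = b0 + 0.5 * (kappa0 * ((mu_n - mu0) ** 2 + 1 / lam_n)
                      + np.sum((x - mu_n) ** 2) + n / lam_n)

print("E_q[mu]  =", mu_n)       # close to the sample mean
print("E_q[tau] =", a_n / b_n)  # close to 1 / 1.5**2 = 0.44
```

Each update is an instance of the exponential-family natural-parameter update above, so the ELBO increases at every step.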

But in most cases, for example Bayesian logistic regression, we do not have conditional conjugacy in our model. In this blog post, we introduce two methods developed in the last decade that tackle the lack of conjugacy. Notice that variational inference is at heart an optimization problem, and these methods are derived by expressing the derivatives of the ELBO as an expectation over the variational distribution $q$: \[\frac{\partial ELBO}{\partial {\phi}} = \mathbb{E}_{q({\theta};{\phi})}[\cdot]\]
Evaluable Models: Black Box Variational Inference

We want to optimize \[\mathcal{L}({\phi}) = \mathbb{E}_{q}[\log p({\theta},{x})] - \mathbb{E}_q[\log q({\theta};{\phi})]\]
and we notice that
\begin{align*}
\triangledown_{{\phi}}\mathcal{L}({\phi}) &= \triangledown_{{\phi}}\int q({\theta};{\phi})\log \frac{p({\theta},{x})}{q(\theta;{\phi})} d{\theta}\\
&= \int q({\theta};{\phi})\triangledown_{{\phi}} \log q({\theta};{\phi})\log \frac{p({\theta},{x})}{q({\theta};{\phi})} + q({\theta};{\phi})\triangledown_{{\phi}}\log \frac{p({\theta},{x})}{q({\theta};{\phi})} d{\theta}\\
&= \mathbb{E}_{q}[\triangledown_{{\phi}}\log q({\theta};{\phi})(\log p({\theta},{x})-\log q({\theta};{\phi}))]
\end{align*}
This is proposed in [3]; the second term in the integrand vanishes since $q\triangledown_{{\phi}}\log \frac{p({\theta},{x})}{q({\theta};{\phi})} = -\triangledown_{{\phi}}q({\theta};{\phi})$, which integrates to zero. We see the gradient is an expectation under the variational distribution, and to evaluate it we only need to:

  • simulate from $q$,
  • evaluate the derivative $\triangledown_{{\phi}}\log q({\theta};{\phi})$,
  • evaluate the model joint density $p({\theta},{x})$.

This significantly relaxes the constraints of CAVI and enlarges the domain of applicable models.
In practice, we use stochastic gradient descent with a noisy unbiased estimator of the gradient, adopting a step-size sequence satisfying the Robbins-Monro conditions \[\sum_j \rho_j =\infty\;\;\;\;\sum_j \rho_j^2 < \infty\]
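For example, $\rho_j = j^{-\kappa}$ with $\kappa\in(1/2,1]$ satisfies both conditions.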
A naive algorithm is as follows:

  • $t \gets 0$, $\delta \gets \infty$
  • While $\delta > \tau$:
    • $t \gets t+1$
    • Draw ${\theta}^1,\dots,{\theta}^S\sim q({\theta};{\phi}_{t-1})$
    • $\hat{\triangledown}_{{\phi}}\mathcal{L}({\phi}_{t-1})\gets \frac{1}{S}\sum_{s=1}^S \triangledown_{{\phi}}\log q({\theta}^s;{\phi}_{t-1})(\log p({\theta}^s,{x})-\log q({\theta}^s;{\phi}_{t-1}))$
    • ${\phi}_t\gets{\phi}_{t-1} + \rho_t\hat{\triangledown}_{{\phi}}\mathcal{L}({\phi}_{t-1})$
    • $\delta \gets ||{\phi}_t - {\phi}_{t-1}||\,/\,||{\phi}_{t-1}||$
  • Output ${\phi}^* = {\phi}_t$
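
Below is a minimal sketch of this naive estimator (my own toy example, not from [3]): the model is a Normal likelihood with known unit variance and a standard Normal prior on the mean, and the variational family is $q(\theta) = \mathcal{N}(m, e^{2s})$ with ${\phi} = (m, s)$, so the score and both log densities are available in closed form:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=1.0, scale=1.0, size=20)   # toy data, true mean 1.0

def log_joint(theta):
    """log p(theta, x) up to a constant: N(0,1) prior, N(theta,1) likelihood."""
    return -0.5 * theta**2 - 0.5 * np.sum((x - theta) ** 2)

def log_q(theta, m, s):
    return -0.5 * (theta - m) ** 2 / np.exp(2 * s) - s - 0.5 * np.log(2 * np.pi)

def score_q(theta, m, s):
    """Gradient of log q with respect to phi = (m, s)."""
    d_m = (theta - m) / np.exp(2 * s)
    d_s = (theta - m) ** 2 / np.exp(2 * s) - 1.0
    return np.array([d_m, d_s])

m, s, S = 0.0, 0.0, 200
for t in range(1, 2001):
    thetas = rng.normal(m, np.exp(s), size=S)           # simulate from q
    ghat = np.mean([score_q(th, m, s) * (log_joint(th) - log_q(th, m, s))
                    for th in thetas], axis=0)          # noisy unbiased gradient
    rho = 0.05 * t ** -0.7                              # Robbins-Monro step size
    m, s = np.array([m, s]) + rho * ghat

# The exact posterior is N(sum(x)/(n+1), 1/(n+1)); compare:
print(m, x.sum() / 21)             # variational vs exact posterior mean
print(np.exp(s), (1 / 21) ** 0.5)  # variational vs exact posterior sd
```

Even on this trivial model the gradient estimates are noisy and the step size must be kept small, which foreshadows the variance problem discussed next.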

However, in practice this algorithm does not produce meaningful results for non-trivial models, since the variance of the estimator grows linearly with the number of parameters in ${\theta}$. Due to this high variance, we need variance reduction techniques.

Rao-Blackwellization

Rao-Blackwellization reduces the variance of some estimator $J(X,Y)$ by defining another estimator \[\hat{J}(X) = \mathbb{E}[J(X,Y)|X]\]
It is clear that the expectation is preserved:\[\mathbb{E}[\hat{J}(X)] = \mathbb{E}[J(X,Y)]\]by the tower law. The variance of this estimator is \[Var(\hat{J}(X)) = Var(J(X,Y)) + \mathbb{E}[\hat{J}(X)^2] - \mathbb{E}[J(X,Y)^2] = Var(J(X,Y)) - \mathbb{E}[(J(X,Y)-\hat{J}(X))^2]\]

Thus the new estimator never has larger variance than $J(X,Y)$, with equality only when $\hat{J}(X) = J(X,Y)$.
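
A tiny illustration (my own, with a made-up estimator): take independent $X, Y \sim \mathcal{N}(0,1)$ and $J(X,Y) = X + Y$ as an estimator of $\mathbb{E}[X+Y] = 0$; conditioning on $X$ gives $\hat{J}(X) = \mathbb{E}[X+Y|X] = X$, which halves the variance:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=100_000)
Y = rng.normal(size=100_000)

J = X + Y      # original estimator: mean 0, variance 2
J_rb = X       # Rao-Blackwellized: E[X + Y | X] = X, variance 1

print(J.mean(), J_rb.mean())   # both approximately 0 (unbiasedness preserved)
print(J.var(), J_rb.var())     # approximately 2 vs 1
```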

We now apply this to BBVI. Assume the approximating family follows the mean-field assumption, and let $p({x},{\theta}) = p_i({x},{\theta}_{(i)})p_{-i}({x},{\theta}_{-i})$,
where $p_i$ collects all the terms containing $\theta_i$, and $\theta_{(i)}$ is the collection of all latent variables that appear in $p_i$.
We can thus rewrite the derivative of the ELBO with respect to $\phi_i$ as \[\hat{\triangledown}_{\phi_i}^{RB}\mathcal{L}(\phi_i) = \mathbb{E}_{q_{(i)}}[\triangledown_{\phi_i}\log q_i(\theta_i;\phi_i)(\log p_i({x},\theta_{(i)})-\log q_i(\theta_i;\phi_i))]\]
This is a Rao-Blackwellization of $\triangledown_{\phi_i}\mathcal{L}({\phi})$, as \[\mathbb{E}_q[\hat{\triangledown}_{\phi_i}\mathcal{L}({\phi}) - \hat{\triangledown}_{\phi_i}^{RB}\mathcal{L}(\phi_i)] = C\,\mathbb{E}_{q_i}[\triangledown_{\phi_i}\log q_i(\theta_i;\phi_i)] = 0\]
with \[C = \mathbb{E}_{q_{-i}}[\log p_{-i}({x},{\theta}_{-i})] - \mathbb{E}_{q_{-i}}[\sum_{j\neq i}\log q_j(\theta_j;\phi_j)]\]
The detailed derivation can be found in [3].

Control variates

We now introduce another method, based on a regression estimator. Suppose we want to estimate some quantity $\mu$ and we have an estimator $f$ with $\mathbb{E}[f(u)] = \mu$, where $u$ is a random variable. Suppose further that we have a "similar" function $h$ such that $\mathbb{E}[h(u)] = \nu$ is known. Then we define a new estimator of $\mu$:\[g(u) = f(u)-\beta(h(u)-\nu)\]

This is clearly an unbiased estimator, and for the variance term\[Var(g(u)) = Var(f(u)) + \beta^2 Var(h(u)) - 2\beta Cov(f(u),h(u))\]

In order to minimize this variance, we choose \[\hat{\beta} = \frac{Cov(h(u),f(u))}{Var(h(u))}\]

This is also the OLS estimator for the linear regression:\[f(u) = \mu + \beta(h(u)-\nu)\] Plugging in this $\hat{\beta}$ we have \[Var(g(u)) = Var(f(u))(1-\rho^2_{fh})\] where $\rho_{fh}$ is the correlation between $f(u)$ and $h(u)$. Such an $h$ is called a control variate.
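
A classic toy example (my own illustration): estimate $\mu = \mathbb{E}[e^U]$ for $U\sim\mathrm{Uniform}(0,1)$, using $h(U) = U$ with known mean $\nu = 1/2$ as the control variate:

```python
import numpy as np

rng = np.random.default_rng(3)
U = rng.uniform(size=100_000)

f = np.exp(U)   # E[f] = e - 1, the target mu
h = U           # control variate with known mean nu = 0.5

beta = np.cov(f, h)[0, 1] / h.var()   # OLS estimate of the optimal beta
g = f - beta * (h - 0.5)              # regression-adjusted estimator

print(f.mean(), g.mean())   # both approximately e - 1 = 1.718...
print(f.var(), g.var())     # the variance drops by a factor of about 60
```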

Improved BBVI

The authors of [3] combined these two methods, choosing $\triangledown_{\phi_i}\log q_i(\theta_i;\phi_i)$ as the control variate for $\hat{\triangledown}_{\phi_i}^{RB}\mathcal{L}(\phi_i)$ (a natural choice, since its expectation under $q_i$ is exactly zero). The resulting algorithm is shown below:

  • $t \gets 0$, $\delta \gets \infty$
  • While $\delta > \tau$:
    • $t \gets t+1$
    • Draw ${\theta}^1,\dots,{\theta}^S\sim q({\theta};{\phi}_{t-1})$
    • For $i \gets 1$ to $n$:
      • $f_i^s \gets \triangledown_{\phi_i}\log q_i(\theta_i^s;{\phi}^{t-1}_{i})(\log p_i({x},{\theta}_{(i)}^s)-\log q_i(\theta_i^s;{\phi}_i^{t-1}))$ for $s=1,\dots,S$
      • $h_i^s\gets \triangledown_{\phi_i}\log q_i(\theta_i^s;{\phi}_i^{t-1})$ for $s=1,\dots,S$
      • $\hat{\beta}_i \gets \hat{Cov}(f_i,h_i)/\hat{Var}(h_i)$, estimated from the $S$ samples
      • $g_i \gets \frac{1}{S}\sum_{s=1}^S \left(f_i^s-\hat{\beta}_ih_i^s\right)$
      • $\phi_i^t\gets \phi_i^{t-1} + \rho_tg_i$
    • $\delta \gets ||{\phi}_t - {\phi}_{t-1}||\,/\,||{\phi}_{t-1}||$
  • Output ${\phi}^* = {\phi}_t$
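
As a sketch of the key step (reusing `log_joint`, `log_q` and `score_q` from the naive BBVI example above, so the same toy-model caveats apply), the per-coordinate control-variate adjustment looks like:

```python
import numpy as np

def cv_gradient(thetas, m, s):
    """Control-variate BBVI gradient for phi = (m, s), one coordinate at a time."""
    fs = np.array([score_q(th, m, s) * (log_joint(th) - log_q(th, m, s))
                   for th in thetas])                    # per-sample f_i^s, shape (S, 2)
    hs = np.array([score_q(th, m, s) for th in thetas])  # per-sample h_i^s
    g = np.empty(2)
    for i in range(2):
        beta = np.cov(fs[:, i], hs[:, i])[0, 1] / hs[:, i].var()
        g[i] = np.mean(fs[:, i] - beta * hs[:, i])  # no centering: E_q[h_i] = 0
    return g
```

Dropping this in place of `ghat` in the earlier loop should reduce the variance of the updates.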

Final Conclusion for BBVI

In [4], the same authors pointed out the limitations of BBVI. They found that the gradient can be very unstable for large values of its inputs, and that adaptive step sizes like AdaGrad need extra tuning. They also found that, in the case of a linear mixed effects model, BBVI under-performs an MH-within-Gibbs sampler. In a further experiment on LDA (latent Dirichlet allocation), a Gibbs sampler converged in a couple of minutes with 20 topics, while BBVI did not produce any reasonable results after hours of iterations with even 2 topics. Thus BBVI requires more experimentation and still has practical limitations.

Differentiable Models: Automatic Differentiation Variational Inference

The idea behind Automatic Differentiation Variational Inference (ADVI) is as follows:

  • Transform the parameter space to the real space via a one-to-one mapping $T:Supp({\theta})\rightarrow\mathbb{R}^k$.
  • Give ${\psi} = T({\theta})$ a joint normal variational distribution. That is, \[q({\psi};{\phi}) = \mathcal{N}({\psi};{\mu},\Sigma)\] Notice that we need to ensure $\Sigma$ is positive definite. One way to do that is the Cholesky factorization $\Sigma = LL^T$, where $L$ is a lower triangular matrix with $(k+1)k/2$ free entries. Overall, ${\phi}$ lives in $\mathbb{R}^{(k+1)k/2+k}$, where $k$ is the dimension of the parameters in our model. This comes with a computational cost, so we may wish to make a mean-field assumption on ${\psi}$.
  • Finally we make the standardization ${\eta} = S_{{\phi}}({\psi}) = L^{-1}({\psi}-{\mu})$. This makes $q({\eta}) = \mathcal{N}({\eta};{0},{I})$.

Following the above recipe, we can rewrite the ELBO optimization as \[{\phi}^* = argmax_{\phi}\; \mathbb{E}_{\mathcal{N}({\eta};{0},{I})}\left[\log p\left({x},T^{-1}(S^{-1}_{{\phi}}({\eta}))\right) + \log |\det J_{T^{-1}}(S_{{\phi}}^{-1}({\eta}))|\right] + \mathbb{H}[q({\psi};{\phi})]\]
In this case, the variational parameters are contained in the transformation $S$. We now give the gradients:\[\triangledown_{{\mu}}\mathcal{L} = \mathbb{E}_{\mathcal{N}({\eta})}[\triangledown_{{\theta}}\log p({x},{\theta})\triangledown_{{\psi}}T^{-1}({\psi}) + \triangledown_{{\psi}}\log|\det J_{T^{-1}}({\psi})|]\]
and \[\triangledown_{L}\mathcal{L} =\mathbb{E}_{\mathcal{N}({\eta})}[\left(\triangledown_{{\theta}}\log p({x},{\theta})\triangledown_{{\psi}}T^{-1}({\psi}) + \triangledown_{{\psi}}\log|\det J_{T^{-1}}({\psi})|\right){\eta}^T] + (L^{-1})^T\]
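
To illustrate, here is a minimal mean-field ADVI sketch (my own toy example, with the gradients coded by hand rather than obtained by automatic differentiation): data $x_i \sim \mathrm{Exp}(\theta)$ with prior $\theta \sim \mathrm{Gamma}(a, b)$, transform $\psi = T(\theta) = \log\theta$ so that $\log|\det J_{T^{-1}}(\psi)| = \psi$, and variational family $q(\psi) = \mathcal{N}(\mu, e^{2\omega})$:

```python
import numpy as np

rng = np.random.default_rng(4)
x = rng.exponential(scale=1 / 2.0, size=200)   # toy data, true rate theta = 2
n, sx, a, b = len(x), x.sum(), 2.0, 1.0        # hypothetical Gamma(a, b) prior

def grad_psi(psi):
    """d/dpsi of [log p(x, e^psi) + psi] (model term plus log-Jacobian term)."""
    theta = np.exp(psi)
    # d log p / d theta = (n + a - 1)/theta - (b + sx); chain rule: dtheta/dpsi = theta
    return (n + a - 1) - (b + sx) * theta + 1.0

mu, omega, S = 0.0, 0.0, 32
for t in range(1, 5001):
    eta = rng.normal(size=S)                   # eta ~ N(0, I)
    psi = mu + np.exp(omega) * eta             # psi = S_phi^{-1}(eta)
    g = grad_psi(psi)
    g_mu = g.mean()                                    # gradient w.r.t. mu
    g_omega = (g * eta).mean() * np.exp(omega) + 1.0   # + entropy gradient
    rho = 0.005 * t ** -0.7
    mu, omega = mu + rho * g_mu, omega + rho * g_omega

# The exact posterior is Gamma(a + n, b + sx); compare posterior means
print(np.exp(mu + np.exp(2 * omega) / 2))   # E_q[theta] under log-normal q(theta)
print((a + n) / (b + sx))                   # exact posterior mean, close to 2
```

Because the gradient acts through samples of $\eta$ rather than the score of $q$, this estimator typically has much lower variance than the BBVI estimator for the same sample budget.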
Now, similar to BBVI, we can use MC sampling and SGD to obtain an approximate gradient and optimize. In [5] they propose a step size of the form
\[\rho_k^i = \eta\times i^{-1/2+\epsilon}\times\left(\tau + \sqrt{s_k^i}\right)^{-1}\]
where \[s_k^i = \alpha (g_k^i)^2 + (1-\alpha)s_k^{i-1}\]
Here $k$ indexes the element and $i$ the iteration: $g_k^i$ is the $k$th element of the gradient vector at iteration $i$, and $s_k^1 = (g_k^1)^2$.

Notice that $\eta$ here is another variable that controls the scale of the step-size sequence; it can be searched over $\{0.001,0.1,1,10,100\}$. $\epsilon$ is set to be small, for example $\epsilon = 10^{-6}$, to satisfy the Robbins-Monro conditions. The last factor keeps a memory of the past gradients. More details can be found in [5].
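
A direct transcription of this schedule (with illustrative values for $\tau$ and $\alpha$, which are tuning constants here, not taken from [5]):

```python
import numpy as np

def advi_step_size(g, s_prev, i, eta=1.0, tau=1.0, alpha=0.1, eps=1e-6):
    """Per-element step size rho_k^i for gradient vector g at iteration i.

    s_prev is the running average s^{i-1}; pass None at the first iteration.
    """
    s = g**2 if s_prev is None else alpha * g**2 + (1 - alpha) * s_prev
    rho = eta * i ** (-0.5 + eps) / (tau + np.sqrt(s))
    return rho, s

# Usage inside an optimization loop (phi and grad assumed given):
#   rho, s = advi_step_size(grad, s, i)
#   phi = phi + rho * grad
```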

It is shown in [5] that the variance of the gradient estimates in ADVI is better controlled compared to BBVI. Its performance is also compared there against well-known MC methods.

[1] John Winn and Christopher M. Bishop. Variational message passing. Journal of Machine Learning Research, 6(23):661–694, 2005.

[2] David M. Blei, Alp Kucukelbir, and Jon D. McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.

[3] Rajesh Ranganath, Sean Gerrish, and David M. Blei. Black box variational inference. arXiv preprint, 2013.

[4] Rajesh Ranganath, Sean Gerrish, and David Blei. Black box variational inference. In Samuel Kaski and Jukka Corander, editors, Proceedings of the Seventeenth International Conference on Artificial Intelligence and Statistics, volume 33 of Proceedings of Machine Learning Research, pages 814–822, Reykjavik, Iceland, 22–25 Apr 2014. PMLR.

[5] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic differentiation variational inference. Journal of Machine Learning Research, 18(1):430–474, 2017.
