
[Optimization] Variational Inference


Variational inference is a method for approximating the posterior $p(\mathbf{z}\vert\mathbf{x})$. It is a key technique behind the Variational AutoEncoder, one of the most famous generative models.

Introduction

Posteriors are often intractable because computing the normalizing constant requires an integral over the latent variables. Historically, people used Monte Carlo (MC) integration to approximate it. However, MC approximation is costly.

Variational Inference (VI) was devised to replace this costly approximation with an optimization problem. The main benefit of VI is that we don't need to know the normalizing constant. This lets us approximate the posterior $p(\mathbf{z} \vert \mathbf{x})$, because the integral for the normalizing constant is the main obstacle. Although VI is mainly used to approximate posteriors, it can be used to approximate any distribution whose normalizing constant is unknown. VI defines a family of distributions, called the variational family, and minimizes a measure of the difference between distributions, the KL-divergence, over that family.

  • We can transform the approximation task into an optimization task by parametrizing the variational distributions.
  • We don't need the normalizing constant because the KL-divergence objective can be rearranged so that the constant drops out of the optimization.

VI is used when we know a distribution only up to its normalizing constant. We define a parametrized distribution and minimize its KL divergence to the target.

Key Terminologies

Before we delve into Variational Inference (VI), it helps to clarify some keywords: inference, evidence, evaluation, prediction, variational, and learning.

What is Statistical Inference?

Statistical inference is the process of using a sample to infer the properties of a population. In a latent variable model, we believe that there are some latent variables that control the events. Therefore, inferring underlying properties can be interpreted as updating our belief about the latent variables ($\mathbf{z}$) based on observations ($\mathbf{x}$). Bayes' rule tells us how to update the information on latent variables.

$$
\begin{align}
p(\mathbf{z} \vert \mathbf{x}) = \frac{p(\mathbf{x},\mathbf{z})}{\int p(\mathbf{x},\mathbf{z}) \, d\mathbf{z}}
\end{align}
$$
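As a small worked example of Bayes' rule, consider a latent variable with only two values, so that the integral in the denominator becomes a sum. The coin setup and all probabilities below are illustrative assumptions, not part of the model discussed in this post.

```python
# Toy latent variable model (all numbers are illustrative assumptions):
# z is which coin was picked, x is the observed flip.
prior = {"fair": 0.5, "biased": 0.5}              # p(z)
likelihood_heads = {"fair": 0.5, "biased": 0.9}   # p(x = heads | z)

# Joint p(x = heads, z) for every value of the latent variable.
joint = {z: likelihood_heads[z] * prior[z] for z in prior}

# Normalizing constant: the integral over z becomes a sum here.
evidence = sum(joint.values())

# Bayes' rule: posterior = joint / evidence.
posterior = {z: joint[z] / evidence for z in joint}

print(posterior)
```

With a discrete latent variable the normalizing constant is cheap to compute. The difficulty discussed in this post arises when the sum is replaced by a high-dimensional integral.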

To assess the results of modeling and inference, we would like to know how well a model fits the observed data $\mathbf{x}$. We can quantify the fit between a model and observed data via the marginal likelihood, also called the evidence.

$$
\begin{align}
p(\mathbf{x}) &= \int p(\mathbf{x},\mathbf{z}) \, d\mathbf{z}
\end{align}
$$
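To make the evidence integral concrete, here is a minimal sketch that estimates it by naive Monte Carlo. It assumes a toy model chosen only for illustration: $\mathbf{z} \sim \mathcal{N}(0, 1)$ and $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$, for which the evidence has the closed form $\mathcal{N}(\mathbf{x}; 0, 2)$. The function names are hypothetical.

```python
import math
import random

def normal_pdf(v, mean, var):
    """Density of N(mean, var) evaluated at v."""
    return math.exp(-0.5 * (v - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)

def evidence_mc(x, num_samples, rng):
    """Naive Monte Carlo estimate of p(x) = E_{z ~ p(z)}[p(x | z)]."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)   # p(x | z) = N(x; z, 1)
    return total / num_samples

def evidence_exact(x):
    """Closed form for this toy model: p(x) = N(x; 0, 2)."""
    return normal_pdf(x, 0.0, 2.0)

rng = random.Random(0)
x_obs = 1.0
estimate = evidence_mc(x_obs, 100_000, rng)
exact = evidence_exact(x_obs)
print(estimate, exact)
```

In one dimension the estimate converges quickly. With high-dimensional latent variables, samples from the prior rarely land where the likelihood is large, which is one reason the MC approach becomes costly.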

Once the model parameter $\theta$ is determined, we can make predictions for new data with the posterior predictive distribution. It is obtained from the prior predictive distribution by replacing the prior on the latent variable $\mathbf{z}$ with the posterior.

$$
\begin{align}
p(x_{new}) &= \int_{\mathbf{z}} p(x_{new} \vert \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}\\
p(x_{new} \vert \mathbf{x}) &= \int_{\mathbf{z}} p(x_{new} \vert \mathbf{z}) \, p(\mathbf{z} \vert \mathbf{x}) \, d\mathbf{z}
\end{align}
$$
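The posterior predictive integral can be read as a two-step sampling recipe: draw $\mathbf{z}$ from the posterior, then draw $x_{new}$ from the likelihood. The sketch below assumes the illustrative toy model $\mathbf{z} \sim \mathcal{N}(0, 1)$, $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$, whose exact posterior is $\mathcal{N}(\mathbf{x}/2, 1/2)$. The model and the function name are assumptions made for this example.

```python
import random

def sample_posterior_predictive(x, num_samples, rng):
    """Draw x_new ~ p(x_new | x) by ancestral sampling."""
    post_mean, post_var = x / 2.0, 0.5   # exact posterior for this toy model
    draws = []
    for _ in range(num_samples):
        z = rng.gauss(post_mean, post_var ** 0.5)   # z ~ p(z | x)
        draws.append(rng.gauss(z, 1.0))             # x_new ~ p(x_new | z)
    return draws

rng = random.Random(0)
x_obs = 1.0
draws = sample_posterior_predictive(x_obs, 200_000, rng)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(mean, var)   # closed form for comparison: mean x/2, variance 1/2 + 1
```

When the exact posterior is not available, the same recipe can be run with an approximate posterior in its place, which is where VI comes in.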

Why Variational?

In functional analysis, a functional is a mapping from a space $X$ into a single (real or complex) value. The domain $X$ can be a space of functions. In that case, the functional is a function that takes another function as input and outputs a single value.

Many problems involve finding an optimal input function that maximizes/minimizes the functional. The mathematical techniques developed to solve this type of problem are collectively known as the calculus of variations.

VI uses the KL divergence, which is a functional: it takes two functions (probability distributions) as input and returns a single value (a measure of the difference between them). VI then finds the optimal input function (the variational distribution) that minimizes the KL-divergence. This is a problem in the calculus of variations, which is why the method is called variational inference.
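The idea of the KL divergence as a functional can be sketched in code: a function that takes two log-density functions and returns one number. The example below integrates numerically on a grid and uses two Gaussians picked only for illustration. The helper names are hypothetical.

```python
import math

def normal_logpdf(mean, std):
    """Return the function z -> log N(z; mean, std^2)."""
    def logpdf(z):
        return -0.5 * math.log(2.0 * math.pi * std ** 2) - 0.5 * ((z - mean) / std) ** 2
    return logpdf

def kl_functional(log_q, log_p, lo=-12.0, hi=12.0, steps=24_000):
    """KL(q || p) = integral of q(z) * (log q(z) - log p(z)) dz, by a midpoint rule."""
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        total += math.exp(log_q(z)) * (log_q(z) - log_p(z)) * dz
    return total

q = normal_logpdf(0.0, 1.0)
p = normal_logpdf(1.0, 2.0)
kl_qp = kl_functional(q, p)   # KL(q || p)
kl_pq = kl_functional(p, q)   # KL(p || q)
print(kl_qp, kl_pq)
```

The two printed values differ, which shows that the KL divergence is not symmetric in its arguments. VI uses the direction $KL(q \| p)$, with the variational distribution as the first argument.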

Learning and Parametrized Model

We introduced a probabilistic model for ML problems and parametrized the model to learn from data. With the parametrized model, learning can be defined as the process of finding an optimal model parameter $\theta$ from observed data, that is, optimizing $\theta$ to maximize the evidence. For ease of computation, we maximize the log evidence, which has the same maximizer because the logarithm is monotonically increasing.

$$
\theta_{max} = \arg \max_\theta p_\theta(\mathbf{x}) = \arg \max_\theta \log p_\theta(\mathbf{x})
$$
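As a rough illustration of learning by maximizing the log evidence, the sketch below assumes a toy parametrized model, $\mathbf{z} \sim \mathcal{N}(\theta, 1)$ and $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$, where the evidence is available in closed form, and searches a grid of candidate $\theta$ values. The model, the data, and the grid are all assumptions made for this example.

```python
import math

def log_evidence(theta, data):
    """log p_theta(x) for the assumed model z ~ N(theta, 1), x | z ~ N(z, 1).

    Marginalizing z gives x ~ N(theta, 2) independently for each observation.
    """
    return sum(
        -0.5 * math.log(2.0 * math.pi * 2.0) - 0.25 * (x - theta) ** 2
        for x in data
    )

data = [0.8, 1.9, 1.4, 2.3, 1.1]                  # made-up observations
grid = [i / 100.0 for i in range(-500, 501)]      # candidate values of theta
theta_max = max(grid, key=lambda t: log_evidence(t, data))
print(theta_max, sum(data) / len(data))
```

For this model the maximizer coincides with the sample mean. In realistic latent variable models the log evidence has no closed form, so it cannot be maximized this directly.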

VI is a general approach to calculate an approximate posterior, which can be applied whether the probabilistic model is parameterized or not. However, it would be better to keep the parametrized notation for further analyses.

Model Param. vs Variational Param.
  • $\theta$ is a model parameter which is introduced to define a probabilistic model.
  • $\phi$ is a variational parameter which is introduced to define a variational family for VI.

Variational Inference

Variational inference is a method for computing an approximation $q_{\phi}(\mathbf{z})$ of the posterior distribution $p_{\theta}(\mathbf{z} \vert \mathbf{x})$ with $\theta$ fixed. Since we want a good approximation, we minimize the difference (divergence) between $p_{\theta}(\mathbf{z} \vert \mathbf{x})$ and $q_{\phi}(\mathbf{z})$ by optimizing $\phi$. The KL-divergence is a standard measure of the difference between two probability distributions, although it is not a metric in the strict sense because it is not symmetric.

Objective: minimize the KL divergence between the surrogate distribution $q_{\phi}(\mathbf{z})$ and the posterior distribution $p_{\theta}(\mathbf{z}\vert \mathbf{x})$.

$$
\min_{\phi} KL(q_{\phi}(\mathbf{z}) \| p_{\theta}(\mathbf{z}\vert \mathbf{x}))
$$

Evidence Lower Bound (ELBO)

We cannot compute the objective $KL(q_{\phi}(\mathbf{z}) \| p_{\theta}(\mathbf{z}\vert \mathbf{x}))$ directly because it requires $p_{\theta}(\mathbf{z}\vert \mathbf{x})$, which is unknown. Let's rewrite the KL divergence.

$$
\begin{align}
KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z}\vert \mathbf{x})) &= \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})} \Big[ \log\frac{q_\phi(\mathbf{z})}{p_\theta(\mathbf{z} \vert \mathbf{x})} \Big] \\
&= \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}[\log q_\phi(\mathbf{z})] - \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}[\log p_\theta(\mathbf{z}\vert \mathbf{x})] \\
&= \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}[\log q_\phi(\mathbf{z})] - \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}[\log p_\theta(\mathbf{x},\mathbf{z})] + \log p_\theta(\mathbf{x}) \\
&= \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}\Big[ \log \frac{q_\phi(\mathbf{z})}{p_\theta(\mathbf{x},\mathbf{z})}\Big] + \log p_\theta(\mathbf{x}) \\
KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z}\vert \mathbf{x})) &= KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{x},\mathbf{z})) + \log p_\theta(\mathbf{x})\\
\log p_\theta(\mathbf{x}) &= \underbrace{KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z} \vert \mathbf{x}))}_{\ge 0} \underbrace{- KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{x},\mathbf{z}))}_{ELBO}
\end{align}
$$

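We can decompose $KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z}\vert \mathbf{x}))$ into $KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{x},\mathbf{z}))$ and the log evidence $\log p_\theta(\mathbf{x})$. Strictly speaking, $KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{x},\mathbf{z}))$ is shorthand for $\mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z})}\big[\log \frac{q_\phi(\mathbf{z})}{p_\theta(\mathbf{x},\mathbf{z})}\big]$ rather than a true KL divergence, because $p_\theta(\mathbf{x},\mathbf{z})$ is not normalized over $\mathbf{z}$. Let's inspect each of the three terms.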

  • $KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z}\vert \mathbf{x}))$
    • The quantity that we want to minimize.
    • Intractable because we don't know $p_\theta(\mathbf{z}\vert \mathbf{x})$.
    • Non-negative, a general property of the KL-divergence.
  • ELBO
    • $ELBO = -KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{x},\mathbf{z}))=\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z})}[\log p_\theta(\mathbf{x},\mathbf{z}) - \log q_\phi(\mathbf{z})]$
    • This term is tractable.
      • In latent variable models, we assume that the complete data likelihood $p_\theta(\mathbf{x},\mathbf{z})$ is tractable.
      • By construction, $q_\phi(\mathbf{z})$ is tractable.
  • $\log p_\theta(\mathbf{x})$
    • The log marginal likelihood, also called the log evidence.
    • This term does not depend on $q_\phi(\cdot)$ or $\phi$.
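The relation between the three terms can be checked numerically. The sketch below assumes a toy model in which everything has a closed form: $\mathbf{z} \sim \mathcal{N}(0, 1)$, $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$, and a Gaussian variational family $q_\phi(\mathbf{z}) = \mathcal{N}(m, s^2)$ with $\phi = (m, s)$. For this model the exact posterior is $\mathcal{N}(\mathbf{x}/2, 1/2)$ and the evidence is $\mathcal{N}(\mathbf{x}; 0, 2)$. The model and the function names are assumptions made for this example.

```python
import math

X_OBS = 1.0  # a single made-up observation

def elbo(m, s, x=X_OBS):
    """Closed-form ELBO = E_q[log p(x, z) - log q(z)] for q = N(m, s^2)."""
    expected_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    expected_log_prior = -0.5 * math.log(2.0 * math.pi) - 0.5 * (m ** 2 + s ** 2)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s ** 2)   # -E_q[log q(z)]
    return expected_log_lik + expected_log_prior + entropy

def kl_to_posterior(m, s, x=X_OBS):
    """KL(q || p(z | x)); the exact posterior of this toy model is N(x/2, 1/2)."""
    post_mean, post_var = x / 2.0, 0.5
    return (0.5 * math.log(post_var / s ** 2)
            + (s ** 2 + (m - post_mean) ** 2) / (2.0 * post_var) - 0.5)

def log_evidence(x=X_OBS):
    """log p(x) with p(x) = N(x; 0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x ** 2 / 4.0

# Different variational parameters give different ELBO and KL values,
# but their sum always equals the same log evidence.
for m, s in [(0.0, 1.0), (0.5, 0.5 ** 0.5), (-1.0, 2.0)]:
    print(m, s, elbo(m, s), kl_to_posterior(m, s), log_evidence())
```

The second setting, $m = \mathbf{x}/2$ and $s^2 = 1/2$, is the exact posterior, so its KL term vanishes and its ELBO equals the log evidence.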

For a fixed model parameter $\theta$, changing the variational parameter $\phi$ does not change the log evidence $\log p_\theta(\mathbf{x})$. Because the log evidence is the sum of the KL term and the ELBO, maximizing the ELBO with respect to $\phi$ is equivalent to minimizing $KL(q_\phi(\mathbf{z}) \| p_\theta(\mathbf{z}\vert \mathbf{x}))$. The resulting $q_\phi(\mathbf{z})$ is the member of the variational family that is closest to $p_{\theta}(\mathbf{z}\vert \mathbf{x})$ in KL divergence.

$$
\begin{align}
&\text{With fixed } \theta: \\
&\arg \max_{\phi} ELBO = \arg \min_{\phi} KL(q_{\phi}(\mathbf{z})\|p_{\theta}(\mathbf{z}\vert \mathbf{x}))
\end{align}
$$
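The equivalence can be illustrated by maximizing the ELBO directly and checking that the result matches the exact posterior. This is a minimal sketch under assumptions that are not part of the post: the toy model $\mathbf{z} \sim \mathcal{N}(0, 1)$, $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$, a Gaussian family $q_\phi(\mathbf{z}) = \mathcal{N}(m, s^2)$, hand-derived gradients, and plain gradient ascent with an arbitrary step size.

```python
import math

def elbo_gradients(m, rho, x):
    """Gradients of the closed-form ELBO for q = N(m, s^2) with s = exp(rho).

    Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1). Up to constants,
    ELBO(m, s) = -0.5 * (x - m)^2 - 0.5 * m^2 - s^2 + log(s).
    """
    s = math.exp(rho)
    grad_m = x - 2.0 * m              # d ELBO / d m
    grad_rho = 1.0 - 2.0 * s ** 2     # d ELBO / d rho, chain rule through s = exp(rho)
    return grad_m, grad_rho

def fit_variational_params(x, steps=500, lr=0.1):
    """Plain gradient ascent on the ELBO, starting from q = N(0, 1)."""
    m, rho = 0.0, 0.0
    for _ in range(steps):
        grad_m, grad_rho = elbo_gradients(m, rho, x)
        m += lr * grad_m
        rho += lr * grad_rho
    return m, math.exp(rho)

x_obs = 1.0
m_opt, s_opt = fit_variational_params(x_obs)
print(m_opt, s_opt ** 2)   # exact posterior for comparison: mean x/2, variance 1/2
```

The optimization never evaluates $p_\theta(\mathbf{z} \vert \mathbf{x})$ or $p_\theta(\mathbf{x})$, yet it recovers the posterior because the Gaussian family happens to contain it. Optimizing $\log s$ instead of $s$ is a design choice that keeps the standard deviation positive.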

ELBO Interpretation

$$
\begin{align}
ELBO &= \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z})}[\log p_\theta(\mathbf{x},\mathbf{z}) - \log q_\phi(\mathbf{z})] \\
&= \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z})}[\log p_{\theta}(\mathbf{z}) + \log p_{\theta}(\mathbf{x} \vert\mathbf{z}) - \log q_\phi(\mathbf{z})] \\
&= \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z})}[\log p_{\theta}(\mathbf{x} \vert\mathbf{z})] - KL(q_{\phi}(\mathbf{z}) \| p_{\theta}(\mathbf{z}))
\end{align}
$$

Rewritten this way, the ELBO consists of two terms:

  • The expected log likelihood $\mathbb{E}_{\mathbf{z} \sim q_{\phi}(\mathbf{z})}[\log p_\theta(\mathbf{x}\vert \mathbf{z})]$, which measures how well latent values drawn from $q_{\phi}(\mathbf{z})$ explain the data. Note that this is not the log evidence, which is $\log p_{\theta}(\mathbf{x})=\log \mathbb{E}_{\mathbf{z} \sim p_{\theta}(\mathbf{z})}[p_\theta(\mathbf{x}\vert \mathbf{z})]$.
  • The KL divergence between the approximate posterior $q_{\phi}(\mathbf{z})$ and the prior $p_{\theta}(\mathbf{z})$.

Therefore, maximizing the ELBO balances

  • maximizing the expected log likelihood
  • minimizing the KL divergence to the prior

In plain terms, maximizing the ELBO pushes it toward its upper bound, the log evidence $\log p_\theta(\mathbf{x})$, by making the surrogate distribution explain the data well while keeping it close to the prior distribution.
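The trade-off between the two terms can be seen numerically. The sketch below assumes the illustrative toy model $\mathbf{z} \sim \mathcal{N}(0, 1)$, $\mathbf{x} \vert \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$ with $q_\phi(\mathbf{z}) = \mathcal{N}(m, s^2)$. It estimates the expected log likelihood from samples of $q_\phi$ and uses the closed-form KL divergence to the prior. The three parameter settings and the function name are arbitrary choices for this example.

```python
import math
import random

def elbo_terms(m, s, x, num_samples, rng):
    """Return (expected log likelihood, KL(q || prior)) for q = N(m, s^2).

    The expected log likelihood is a Monte Carlo average over z = m + s * eps;
    the KL term to the N(0, 1) prior is available in closed form.
    """
    total = 0.0
    for _ in range(num_samples):
        z = m + s * rng.gauss(0.0, 1.0)                               # z ~ q(z)
        total += -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - z) ** 2   # log p(x | z)
    expected_log_lik = total / num_samples
    kl_to_prior = 0.5 * (s ** 2 + m ** 2 - 1.0) - math.log(s)
    return expected_log_lik, kl_to_prior

rng = random.Random(0)
x_obs = 1.0
log_evidence = -0.5 * math.log(4.0 * math.pi) - x_obs ** 2 / 4.0   # log N(x; 0, 2)

results = {}
for name, (m, s) in {"prior": (0.0, 1.0),
                     "posterior": (0.5, 0.5 ** 0.5),
                     "data-hugging": (1.0, 0.1)}.items():
    rec, kl = elbo_terms(m, s, x_obs, 50_000, rng)
    results[name] = (rec, kl, rec - kl)
    print(name, rec, kl, rec - kl, log_evidence)
```

Setting $q_\phi$ equal to the prior makes the KL term zero but explains the data poorly, while a narrow $q_\phi$ centered on the observation has the best expected log likelihood and a large KL penalty. The exact posterior sits between the two and attains the highest ELBO.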

Wrapup

  • VI is an approach to approximate a distribution whose normalizing constant is intractable.
  • When we are using a latent variable model, we assume that the complete-data likelihood is tractable. The posterior $p_{\theta}(\mathbf{z}\vert \mathbf{x})=\frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{p_{\theta}(\mathbf{x})}$ is intractable because its denominator (the normalizing constant) is intractable. Therefore, VI is appropriate for approximating the posterior.
  • To apply VI, we define a set of variational distributions called the variational family, a parametrized family of distributions $q_{\phi}(\mathbf{z})$.
  • By maximizing the ELBO over $\phi$, $q_{\phi}(\mathbf{z})$ becomes the best approximation of $p_{\theta}(\mathbf{z}\vert \mathbf{x})$ within the variational family.

Resources