We can’t minimize this KL divergence directly because it involves the log evidence, which is unknown, but we can minimize it indirectly by maximizing the evidence lower bound (ELBO):
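In standard notation, with data $x$, latent variables $z$, and variational parameters $\lambda$, the ELBO can be written as follows (a standard form, reconstructed here rather than copied from PyFlux’s own notation):

```latex
\mathcal{L}(\lambda)
  = \mathbb{E}_{q(z;\lambda)}\big[\log p(x, z) - \log q(z;\lambda)\big]
  = \log p(x) - \mathrm{KL}\big(q(z;\lambda)\,\|\,p(z \mid x)\big)
```

Since $\log p(x)$ does not depend on $\lambda$, maximizing the ELBO is equivalent to minimizing the KL divergence.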
There are many choices of approximating distribution. The simplest is a mean-field approximation for latent variables:
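A mean-field family factorizes over the latent variables, each with its own variational parameters (standard form, assumed here):

```latex
q(z;\lambda) = \prod_{i=1}^{m} q(z_i;\lambda_i)
```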
PyFlux currently implements this approximation with Gaussian distributions; future versions will allow more flexible choices of distribution and dependency structure. BBVI is a general-purpose methodology for conducting variational inference. Traditional VI imposed quite restrictive requirements (such as exponential-family conjugacy), but BBVI by Ranganath et al. (2013) allows for non-conjugate models. The basic idea is to approximate the required expectations by Monte Carlo: we simulate from the variational approximation, compute a noisy gradient estimate, and use a stochastic optimization technique to follow the gradient and maximize the ELBO. The gradient is:
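In standard score-function (REINFORCE-style) form, the ELBO gradient is:

```latex
\nabla_\lambda \mathcal{L}
  = \mathbb{E}_{q(z;\lambda)}\Big[\nabla_\lambda \log q(z;\lambda)\,
    \big(\log p(x, z) - \log q(z;\lambda)\big)\Big]
```

This form requires only that we can sample from $q$ and evaluate $\log p(x,z)$, which is what makes the method “black box”.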
Simulating from the variational approximation, the naive Monte Carlo estimator is:
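With $S$ draws $z_s \sim q(z;\lambda)$, the naive estimator is (standard form):

```latex
\nabla_\lambda \mathcal{L}
  \approx \frac{1}{S} \sum_{s=1}^{S}
    \nabla_\lambda \log q(z_s;\lambda)\,
    \big(\log p(x, z_s) - \log q(z_s;\lambda)\big)
```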
But this naive estimator has very high variance, so variance reduction techniques are needed in practice. Ranganath et al. (2013) implement Rao-Blackwellisation and control variates to reduce variance, and use AdaGrad for stochastic optimization to follow the gradient. In the PyFlux implementation for time series models we use control variates, following Ranganath, as well as Rao-Blackwellisation for some model types. The user has a choice of optimizer: RMSProp or ADAM. The user can also choose the batch_size for the gradient estimates; the default value is relatively high to ensure smoother convergence. The final latent variable estimates are taken as an average over the final ten percent of iterations.
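The full BBVI loop can be sketched in NumPy on a toy target. This is a minimal illustration of the score-function estimator with control variates and RMSProp ascent, not PyFlux’s actual implementation; the Gaussian toy target and all names below are hypothetical:

```python
import numpy as np

# Toy stand-in for a model's unnormalised log joint density log p(x, z):
# an isotropic Gaussian over z with a known mean (for illustration only).
TARGET_MU = np.array([1.0, -1.0])

def log_p(z):
    return -0.5 * np.sum((z - TARGET_MU) ** 2, axis=-1)

def log_q(z, mu, log_s):
    # Mean-field Gaussian approximation q(z; mu, s) with s = exp(log_s)
    s2 = np.exp(2.0 * log_s)
    return -0.5 * np.sum((z - mu) ** 2 / s2 + 2.0 * log_s + np.log(2.0 * np.pi),
                         axis=-1)

def bbvi(iters=2000, n_samples=50, lr=0.05, seed=0):
    rng = np.random.default_rng(seed)
    mu, log_s = np.zeros(2), np.zeros(2)
    v_mu, v_ls = np.zeros(2), np.zeros(2)   # RMSProp accumulators
    history = []
    for _ in range(iters):
        s = np.exp(log_s)
        z = mu + s * rng.standard_normal((n_samples, 2))
        # ELBO integrand per sample: log p(x, z) - log q(z)
        f = (log_p(z) - log_q(z, mu, log_s))[:, None]
        # Score functions of q w.r.t. mu and log_s
        g_mu = (z - mu) / s ** 2
        g_ls = (z - mu) ** 2 / s ** 2 - 1.0

        def cv_grad(g):
            # Control variate: subtract a * g, where E[g] = 0 so the
            # estimator stays unbiased; a is chosen to minimise variance.
            h = f * g
            a = (np.mean(h * g, axis=0) - np.mean(h, axis=0) * np.mean(g, axis=0)) \
                / (np.var(g, axis=0) + 1e-12)
            return np.mean(h - a * g, axis=0)

        grad_mu, grad_ls = cv_grad(g_mu), cv_grad(g_ls)
        # RMSProp updates (gradient *ascent* on the ELBO)
        v_mu = 0.9 * v_mu + 0.1 * grad_mu ** 2
        v_ls = 0.9 * v_ls + 0.1 * grad_ls ** 2
        mu = mu + lr * grad_mu / np.sqrt(v_mu + 1e-8)
        log_s = log_s + lr * grad_ls / np.sqrt(v_ls + 1e-8)
        history.append(mu.copy())
    # As in PyFlux, average the estimates over the final ten percent of iterations
    return np.mean(history[-max(1, iters // 10):], axis=0)
```

Running `bbvi()` should recover a mean close to `TARGET_MU`, since the optimal mean-field Gaussian for a Gaussian target is the target itself.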
We will illustrate BBVI by estimating a GAS model with a Normal family. We use the following dataset:
import numpy as np
import pyflux as pf
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

data = pd.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/MASS/drivers.csv")
accidents = data['drivers'].values

plt.figure(figsize=(15,5))
plt.plot(data['time'].values, accidents)
plt.ylabel('Deaths')
plt.xlabel('Year')
plt.title("Deaths of Car Drivers in Great Britain 1969-84");
Since we are dealing with large counts, a Normal model should work fine.
model = pf.GAS(data=accidents, ar=3, sc=3, family=pf.GASNormal())
model.latent_variables.adjust_prior(0, prior=pf.Uniform())
model.latent_variables.adjust_prior([1,4], prior=pf.Normal(0,3))
model.latent_variables.adjust_prior([2,5], prior=pf.Normal(0,2.70))
model.latent_variables.adjust_prior([3,6], prior=pf.Normal(0,2.40))
print(model.latent_variables)
Index    Latent Variable           Prior           Prior Latent Vars         V.I. Dist  Transform
======== ========================= =============== ========================= ========== ==========
0        Constant                  Uniform         n/a (non-informative)     Normal     None
1        AR(1)                     Normal          mu0: 0, sigma0: 3         Normal     None
2        AR(2)                     Normal          mu0: 0, sigma0: 2.7       Normal     None
3        AR(3)                     Normal          mu0: 0, sigma0: 2.4       Normal     None
4        SC(1)                     Normal          mu0: 0, sigma0: 3         Normal     None
5        SC(2)                     Normal          mu0: 0, sigma0: 2.7       Normal     None
6        SC(3)                     Normal          mu0: 0, sigma0: 2.4       Normal     None
7        Normal Scale              Uniform         n/a (non-informative)     Normal     exp
We next run BBVI on this model. We will use the RMSProp optimizer.
x = model.fit('BBVI', iterations=1000, optimizer='RMSProp')
x.summary()
10% done : ELBO is -1777.70372532
20% done : ELBO is -1755.17341514
30% done : ELBO is -1720.58343452
40% done : ELBO is -1680.33880128
50% done : ELBO is -1629.45634868
60% done : ELBO is -1628.94228531
70% done : ELBO is -1628.91948274
80% done : ELBO is -1628.88687975
90% done : ELBO is -1628.83390833
100% done : ELBO is -1628.72716933

Final model ELBO is -1628.78277093
Normal GAS(3,0,3)
======================================================= ================================================
Dependent Variable: Series                              Method: BBVI
Start Date: 3                                           Unnormalized Log Posterior: -1611.9611
End Date: 191                                           AIC: 3239.92219832
Number of observations: 189                             BIC: 3265.85617444
========================================================================================================
Latent Variable                          Median             Mean               95% Credibility Interval
======================================== ================== ================== =========================
Constant                                 1670.3274          1670.3279          (1670.2466 | 1670.409)
AR(1)                                    -0.2844            -0.2843            (-0.3615 | -0.2068)
AR(2)                                    0.2095             0.2094             (0.1269 | 0.2904)
AR(3)                                    0.0897             0.0901             (0.0125 | 0.1686)
SC(1)                                    0.6276             0.6278             (0.5465 | 0.7083)
SC(2)                                    0.682              0.6818             (0.6016 | 0.7619)
SC(3)                                    0.2457             0.2457             (0.1648 | 0.3256)
Normal Scale                             81.7985            81.7583            (75.2233 | 88.8316)
========================================================================================================
We can plot the latent variables below with plot_z:
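Continuing the session above (the figsize argument is illustrative):

```python
model.plot_z(figsize=(15,5))
```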
We can check the in-sample fit of the model using plot_fit:
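Continuing the session above (the figsize argument is illustrative):

```python
model.plot_fit(figsize=(15,10))
```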
We can predict forward with the model and plot the results using plot_predict:
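Continuing the session above; the forecast horizon h and the past_values setting shown here are illustrative choices, not defaults:

```python
model.plot_predict(h=10, past_values=30, figsize=(15,5))
```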