
AutoDiff-Inference: Automatic Differentiation Inference. This repository combines ADVI's change-of-variables transformation with the Laplace Approximation to provide better inference for constrained parameters. Work done as part of the development of Bijax.


Madhav-Kanda/AutoDiff-Inference


AutoDiff-Inference (Bijax)


This repository contains code for implementing Automatic Differentiation Variational Inference (ADVI) and different variants of Laplace Approximation based on major research papers.

Features:

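  • ADVI: an implementation of Automatic Differentiation Variational Inference.
  • Laplace Approximation: an implementation of the Laplace Approximation for constrained variables that, inspired by ADVI, maps the parameter through a bijector to an unconstrained space before fitting the Gaussian.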
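The idea behind the constrained variant can be written out as a change-of-variables Laplace approximation. This is a sketch of the standard derivation, not a transcription of the repository's code; the symbols $T$ (bijector), $z$ (unconstrained parameter), and $s^2$ (approximate variance) are introduced here for illustration.

```latex
\begin{aligned}
\theta &= T(z), \qquad z \in \mathbb{R} \\
\log p(z \mid \mathcal{D}) &= \log p(\mathcal{D} \mid T(z)) + \log p(T(z)) + \log \left| T'(z) \right| + \text{const} \\
\hat{z} &= \arg\max_z \log p(z \mid \mathcal{D}), \qquad
s^2 = \left[ -\frac{d^2}{dz^2} \log p(z \mid \mathcal{D}) \Big|_{z = \hat{z}} \right]^{-1} \\
q(z) &= \mathcal{N}(z \mid \hat{z}, s^2), \qquad q(\theta) = \text{push-forward of } q(z) \text{ through } T
\end{aligned}
```

With $T$ the identity, the Jacobian term is zero and this reduces to the ordinary Laplace approximation on $\theta$. For a Beta$(\alpha, \beta)$ posterior kernel and $T$ the sigmoid, the log density in $z$ is $\alpha \log \sigma(z) + \beta \log(1 - \sigma(z))$ up to a constant, which gives $\hat{z} = \log(\alpha / \beta)$ and $s^2 = (\alpha + \beta) / (\alpha \beta)$.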

Laplace Approximation (LA)

Implementation

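```python
import jax
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp

# LaplaceApproximation comes from this repository (its import path is not shown here).
tfd = tfp.distributions

## Creation of the dataset for Laplace Approximation
data_dist = tfd.Bernoulli(probs=0.7)  # coin with P(heads) = 0.7
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))  # 100 tosses
prior_theta = [3.0, 5.0]  # Beta(3, 5) prior hyperparameters

## Bernoulli likelihood function
def likelihood_fn(theta, data):
    # Log-likelihood of all tosses given P(heads) = theta
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

# For Posterior distribution (Beta-Bernoulli conjugacy)
alpha = prior_theta[0] + data.sum()               # prior a + number of heads
beta = prior_theta[1] + len(data) - data.sum()    # prior b + number of tails
```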
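The `alpha`/`beta` update above relies on Beta-Bernoulli conjugacy. As a dependency-free sanity check (a sketch that is not part of the repository; the helper names and the illustrative count of 70 heads in 100 tosses are assumptions), the closed-form posterior mean can be compared against a brute-force grid normalization of prior × likelihood:

```python
import math

def log_unnormalized_posterior(theta, heads, n, a, b):
    """Beta(a, b) log prior plus Bernoulli log-likelihood, up to a constant."""
    log_prior = (a - 1) * math.log(theta) + (b - 1) * math.log(1 - theta)
    log_lik = heads * math.log(theta) + (n - heads) * math.log(1 - theta)
    return log_prior + log_lik

def grid_posterior_mean(heads, n, a, b, num_points=10_000):
    """Posterior mean obtained by normalizing on a midpoint grid over (0, 1)."""
    thetas = [(i + 0.5) / num_points for i in range(num_points)]
    logs = [log_unnormalized_posterior(t, heads, n, a, b) for t in thetas]
    m = max(logs)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(l - m) for l in logs]
    return sum(t * w for t, w in zip(thetas, weights)) / sum(weights)

# Illustrative data: 70 heads in 100 tosses, Beta(3, 5) prior
heads, n, a, b = 70, 100, 3.0, 5.0
alpha = a + heads        # posterior alpha
beta = b + n - heads     # posterior beta
closed_form_mean = alpha / (alpha + beta)
grid_mean = grid_posterior_mean(heads, n, a, b)
print(closed_form_mean, grid_mean)
```

The two means agree to several decimal places, which is what conjugacy predicts: the grid posterior and Beta(`alpha`, `beta`) are the same distribution.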

Normal Laplace Approximation

```python
## Using Identity bijector for normal Laplace Approximation
la = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Identity(),  # no change of variables: Gaussian fit directly on theta
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)  ## True posterior

fig = la.plot_approx_posterior(true_posterior=true_posterior)
plt.xlim(-0.5, 1.5)  # widen the axis beyond theta's support [0, 1]
plt.savefig("plots/la_coin_toss.png")  # save before opening any new figure
```
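To see what the Identity-bijector fit computes, here is a dependency-free sketch of a one-dimensional Laplace approximation directly on theta. It is an illustration under stated assumptions, not the repository's `LaplaceApproximation`: the helper names are hypothetical, derivatives use central finite differences in place of JAX autodiff, and the Beta(73, 35) and Beta(3, 2) posteriors are illustrative.

```python
import math

def laplace_1d(log_density, x0, steps=50, h=1e-4):
    """Return (mode, variance) of a Gaussian Laplace fit to a 1-D log density.

    Newton's method locates the mode; derivatives are central finite
    differences, standing in for automatic differentiation.
    """
    def d1(x):
        return (log_density(x + h) - log_density(x - h)) / (2 * h)

    def d2(x):
        return (log_density(x + h) - 2 * log_density(x) + log_density(x - h)) / h**2

    x = x0
    for _ in range(steps):
        x -= d1(x) / d2(x)  # Newton step toward the mode
    return x, -1.0 / d2(x)  # variance = inverse of negative curvature at the mode

def normal_cdf(x):
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def beta_log_kernel(alpha, beta):
    """Unnormalized log density of Beta(alpha, beta) for theta in (0, 1)."""
    return lambda t: (alpha - 1) * math.log(t) + (beta - 1) * math.log(1 - t)

# Illustrative posterior: 70 heads in 100 tosses with a Beta(3, 5) prior
mode, var = laplace_1d(beta_log_kernel(73.0, 35.0), x0=0.5)
# Closed form: mode = (alpha-1)/(alpha+beta-2), var = (alpha-1)(beta-1)/(alpha+beta-2)^3

# With little data, the Gaussian on theta puts probability outside [0, 1]
small_mode, small_var = laplace_1d(beta_log_kernel(3.0, 2.0), x0=0.5)
sd = math.sqrt(small_var)
leak = normal_cdf((0.0 - small_mode) / sd) + 1.0 - normal_cdf((1.0 - small_mode) / sd)
print(mode, var, leak)
```

For the small Beta(3, 2) posterior, roughly 12% of the Gaussian's mass lies outside [0, 1], which is the weakness the constrained variant addresses.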

AutoDiff Laplace Approximation

```python
## Using Sigmoid bijector for constrained Laplace Approximation
la_cov = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Sigmoid(),  # theta = sigmoid(z): Gaussian fit on unconstrained z
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)

fig_cov = la_cov.plot_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/la_cov_coin_toss.png")  # save the current figure first
plt.figure()  # then start a fresh figure for the next plot

fig = la_cov.plot_log_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/log_la_cov_coin_toss.png")
```
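The Sigmoid-bijector version can be sketched the same way, fitting the Gaussian in z = logit(theta) and including the log-Jacobian term. This assumes the repository's constrained fit adds that term, as ADVI does; the helper names are hypothetical, gradients are analytic rather than autodiff, and Beta(3, 2) is an illustrative small-data posterior.

```python
import math

def sigmoid_laplace(alpha, beta, steps=50):
    """Laplace fit in z = logit(theta) for a Beta(alpha, beta) posterior.

    Log density in z (Beta kernel plus log-Jacobian log sigmoid'(z)):
        alpha * log sigmoid(z) + beta * log(1 - sigmoid(z)) + const
    """
    z = 0.0
    for _ in range(steps):
        s = 1.0 / (1.0 + math.exp(-z))
        grad = alpha * (1.0 - s) - beta * s
        hess = -(alpha + beta) * s * (1.0 - s)
        z -= grad / hess  # Newton step toward the mode
    s = 1.0 / (1.0 + math.exp(-z))
    return z, 1.0 / ((alpha + beta) * s * (1.0 - s))

def logit_normal_pdf(theta, mu, var):
    """Density on (0, 1) of theta = sigmoid(z) with z ~ N(mu, var)."""
    z = math.log(theta / (1.0 - theta))
    normal = math.exp(-0.5 * (z - mu) ** 2 / var) / math.sqrt(2.0 * math.pi * var)
    return normal / (theta * (1.0 - theta))  # divide by |dtheta/dz|

# Illustrative small-data posterior Beta(3, 2)
z_hat, z_var = sigmoid_laplace(3.0, 2.0)
# Closed form: z_hat = log(alpha / beta), z_var = (alpha + beta) / (alpha * beta)

# The push-forward density lives entirely on (0, 1): a midpoint integral is close to 1
N = 20_000
total = sum(logit_normal_pdf((i + 0.5) / N, z_hat, z_var) for i in range(N)) / N
print(z_hat, z_var, total)
```

Because every sample of z maps into (0, 1), the approximation cannot assign probability to impossible values of theta, unlike a Gaussian fit directly on theta.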



In addition to the Laplace approximation library, the repository includes two notebooks demonstrating diagonal Laplace approximation and low-rank Laplace approximation.
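As a rough illustration of what a diagonal Laplace approximation gives up, the sketch below compares marginal variances from the full covariance (the inverse of the negative Hessian at the mode) with one common diagonal variant that keeps only the Hessian's diagonal. The notebooks may formulate it differently; the 2×2 Hessian is hypothetical, and the low-rank variant is not sketched here.

```python
def full_and_diagonal_variances(H):
    """Marginal variances from a symmetric 2x2 negative Hessian H (the Laplace precision).

    Full Laplace: covariance = inverse(H), so variances are the diagonal of inverse(H).
    Diagonal variant: drop off-diagonal terms of H, so variance_i = 1 / H[i][i].
    """
    (a, b), (c, d) = H
    det = a * d - b * c
    full = (d / det, a / det)   # diagonal entries of inverse(H)
    diagonal = (1.0 / a, 1.0 / d)
    return full, diagonal

H = [[2.0, 1.5], [1.5, 2.0]]  # hypothetical, strongly correlated posterior
full, diag = full_and_diagonal_variances(H)
print(full, diag)
```

For a positive-definite precision matrix, each diagonal-variant variance is at most the corresponding full marginal variance, so ignoring correlations makes the approximation overconfident (here 0.5 versus about 1.14).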


Automatic Differentiation Variational Inference (ADVI)

Implementation

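```python
import jax
import tensorflow_probability.substrates.jax as tfp

# ADVI comes from this repository (its import path is not shown here).
tfd = tfp.distributions

data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]  # Beta(3, 5) prior hyperparameters

def likelihood_fn(theta, data):
    # Bernoulli log-likelihood of all tosses given P(heads) = theta
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

advi = ADVI(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.NormalCDF(),  # theta = Phi(z) maps the real line onto (0, 1)
    likelihood=likelihood_fn,
)
appx_post = advi.approx_posterior(data)
```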
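The ADVI fit above can be illustrated with a small, dependency-free sketch: a Gaussian q(z) = N(mu, sigma^2) over the unconstrained z, with theta = Phi(z) as in the `NormalCDF` bijector, trained by stochastic gradient ascent on the ELBO using the reparameterization z = mu + sigma * eps. This is not the repository's `ADVI` class. The helper names, the hand-derived gradient (instead of JAX autodiff), the plain fixed-step SGD (instead of an adaptive step size), and the Beta(73, 35) posterior (70 heads in 100 tosses with a Beta(3, 5) prior) are all assumptions.

```python
import math
import random

SQRT2 = math.sqrt(2.0)

def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def grad_log_joint(z, alpha, beta):
    """d/dz of (alpha-1) log Phi(z) + (beta-1) log(1 - Phi(z)) + log phi(z).

    The first two terms are the Beta posterior kernel at theta = Phi(z);
    log phi(z) is the log-Jacobian of the NormalCDF transform.
    """
    pdf = std_normal_pdf(z)
    cdf = 0.5 * math.erfc(-z / SQRT2)  # Phi(z)
    sf = 0.5 * math.erfc(z / SQRT2)    # 1 - Phi(z), accurate in the upper tail
    return (alpha - 1) * pdf / cdf - (beta - 1) * pdf / sf - z

def advi_normal_cdf(alpha, beta, steps=5000, num_samples=10, lr=1e-3, seed=0):
    """Fit q(z) = N(mu, sigma^2) by reparameterized stochastic gradient ascent on the ELBO."""
    rng = random.Random(seed)
    mu, log_sigma = 0.0, -1.0
    for _ in range(steps):
        sigma = math.exp(log_sigma)
        g_mu, g_log_sigma = 0.0, 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            g = grad_log_joint(mu + sigma * eps, alpha, beta)
            g_mu += g
            g_log_sigma += g * sigma * eps  # chain rule through z = mu + sigma * eps
        g_mu /= num_samples
        g_log_sigma = g_log_sigma / num_samples + 1.0  # +1 from the Gaussian entropy
        mu += lr * g_mu
        log_sigma += lr * g_log_sigma
    return mu, math.exp(log_sigma)

alpha, beta = 73.0, 35.0
mu, sigma = advi_normal_cdf(alpha, beta)
theta_median = 0.5 * math.erfc(-mu / SQRT2)  # Phi(mu): median of the push-forward q(theta)
print(mu, sigma, theta_median)
```

Optimizing log(sigma) rather than sigma keeps the scale positive without constraints, which mirrors the same change-of-variables idea the bijector applies to theta.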
