Changes in EP/EPDTC to fix numerical issues and increase the flexibility of the inference.

Changes to avoid numerical issues and improve the performance:
    - Keep value of the EP parameters between calls
    - Enforce positivity of tau_tilde
    - Stable computation of the EP moments for the Bernoulli likelihood
    - Compute marginal in the GP model without directly inverting tau_tilde

    Changes to improve the flexibility:
    - Add parameter for maximum number of iterations
    - Distinguish between alternated/nested mode
    - Distinguish between sequential/parallel updates in EP
This commit is contained in:
Moreno 2017-02-17 11:35:30 +00:00
parent 627c878455
commit 0c248e7520
7 changed files with 522 additions and 117 deletions

View file

@ -2,7 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf, derivLogCdfNormal, logCdfNormal
from . import link_functions
from .likelihood import Likelihood
@ -59,24 +59,24 @@ class Bernoulli(Likelihood):
raise ValueError("bad value for Bernoulli observation (0, 1)")
if isinstance(self.gp_link, link_functions.Probit):
z = sign*v_i/np.sqrt(tau_i**2 + tau_i)
Z_hat = std_norm_cdf(z)
Z_hat = np.where(Z_hat==0, 1e-15, Z_hat)
phi = std_norm_pdf(z)
phi_div_Phi = derivLogCdfNormal(z)
log_Z_hat = logCdfNormal(z)
mu_hat = v_i/tau_i + sign*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
mu_hat = v_i/tau_i + sign*phi_div_Phi/np.sqrt(tau_i**2 + tau_i)
sigma2_hat = 1./tau_i - (phi_div_Phi/(tau_i**2+tau_i))*(z+phi_div_Phi)
elif isinstance(self.gp_link, link_functions.Heaviside):
a = sign*v_i/np.sqrt(tau_i)
Z_hat = np.max(1e-13, std_norm_cdf(z))
N = std_norm_pdf(a)
mu_hat = v_i/tau_i + sign*N/Z_hat/np.sqrt(tau_i)
sigma2_hat = (1. - a*N/Z_hat - np.square(N/Z_hat))/tau_i
z = sign*v_i/np.sqrt(tau_i)
phi_div_Phi = derivLogCdfNormal(z)
log_Z_hat = logCdfNormal(z)
mu_hat = v_i/tau_i + sign*phi_div_Phi/np.sqrt(tau_i)
sigma2_hat = (1. - a*phi_div_Phi - np.square(phi_div_Phi))/tau_i
else:
#TODO: do we want to revert to numerical quadrature here?
raise ValueError("Exact moment matching not available for link {}".format(self.gp_link.__name__))
return Z_hat, mu_hat, sigma2_hat
# TODO: Output log_Z_hat instead of Z_hat (needs to be change in all others likelihoods)
return np.exp(log_Z_hat), mu_hat, sigma2_hat
def variational_expectations(self, Y, m, v, gh_points=None, Y_metadata=None):
if isinstance(self.gp_link, link_functions.Probit):