mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
Changes in EP/EPDTC to fix numerical issues and increase the flexibility of the inference.
Changes to avoid numerical issues and improve the performance:
- Keep value of the EP parameters between calls
- Enforce positivity of tau_tilde
- Stable computation of the EP moments for the Bernoulli likelihood
- Compute marginal in the GP model without directly inverting tau_tilde
Changes to improve the flexibility:
- Add parameter for maximum number of iterations
- Distinguish between alternated/nested mode
- Distinguish between sequential/parallel updates in EP
This commit is contained in:
parent
627c878455
commit
0c248e7520
7 changed files with 522 additions and 117 deletions
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import numpy as np
|
||||
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
|
||||
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf, derivLogCdfNormal, logCdfNormal
|
||||
from . import link_functions
|
||||
from .likelihood import Likelihood
|
||||
|
||||
|
|
@ -59,24 +59,24 @@ class Bernoulli(Likelihood):
|
|||
raise ValueError("bad value for Bernoulli observation (0, 1)")
|
||||
if isinstance(self.gp_link, link_functions.Probit):
|
||||
z = sign*v_i/np.sqrt(tau_i**2 + tau_i)
|
||||
Z_hat = std_norm_cdf(z)
|
||||
Z_hat = np.where(Z_hat==0, 1e-15, Z_hat)
|
||||
phi = std_norm_pdf(z)
|
||||
phi_div_Phi = derivLogCdfNormal(z)
|
||||
log_Z_hat = logCdfNormal(z)
|
||||
|
||||
mu_hat = v_i/tau_i + sign*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
|
||||
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
|
||||
mu_hat = v_i/tau_i + sign*phi_div_Phi/np.sqrt(tau_i**2 + tau_i)
|
||||
sigma2_hat = 1./tau_i - (phi_div_Phi/(tau_i**2+tau_i))*(z+phi_div_Phi)
|
||||
|
||||
elif isinstance(self.gp_link, link_functions.Heaviside):
|
||||
a = sign*v_i/np.sqrt(tau_i)
|
||||
Z_hat = np.max(1e-13, std_norm_cdf(z))
|
||||
N = std_norm_pdf(a)
|
||||
mu_hat = v_i/tau_i + sign*N/Z_hat/np.sqrt(tau_i)
|
||||
sigma2_hat = (1. - a*N/Z_hat - np.square(N/Z_hat))/tau_i
|
||||
z = sign*v_i/np.sqrt(tau_i)
|
||||
phi_div_Phi = derivLogCdfNormal(z)
|
||||
log_Z_hat = logCdfNormal(z)
|
||||
mu_hat = v_i/tau_i + sign*phi_div_Phi/np.sqrt(tau_i)
|
||||
sigma2_hat = (1. - a*phi_div_Phi - np.square(phi_div_Phi))/tau_i
|
||||
else:
|
||||
#TODO: do we want to revert to numerical quadrature here?
|
||||
raise ValueError("Exact moment matching not available for link {}".format(self.gp_link.__name__))
|
||||
|
||||
return Z_hat, mu_hat, sigma2_hat
|
||||
# TODO: Output log_Z_hat instead of Z_hat (needs to be change in all others likelihoods)
|
||||
return np.exp(log_Z_hat), mu_hat, sigma2_hat
|
||||
|
||||
def variational_expectations(self, Y, m, v, gh_points=None, Y_metadata=None):
|
||||
if isinstance(self.gp_link, link_functions.Probit):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue