mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
Fixed log predictive density, added option for LOO to provide some intemediate variables
This commit is contained in:
parent
fe0a4285ca
commit
361f0a5274
2 changed files with 22 additions and 13 deletions
|
|
@ -40,7 +40,7 @@ class Laplace(LatentFunctionInference):
|
||||||
self.first_run = True
|
self.first_run = True
|
||||||
self._previous_Ki_fhat = None
|
self._previous_Ki_fhat = None
|
||||||
|
|
||||||
def LOO(self, kern, X, Y, likelihood, posterior, Y_metadata=None, K=None):
|
def LOO(self, kern, X, Y, likelihood, posterior, Y_metadata=None, K=None, f_hat=None, W=None, Ki_W_i=None):
|
||||||
"""
|
"""
|
||||||
Leave one out log predictive density as found in
|
Leave one out log predictive density as found in
|
||||||
"Bayesian leave-one-out cross-validation approximations for Gaussian latent variable models"
|
"Bayesian leave-one-out cross-validation approximations for Gaussian latent variable models"
|
||||||
|
|
@ -51,13 +51,19 @@ class Laplace(LatentFunctionInference):
|
||||||
if K is None:
|
if K is None:
|
||||||
K = kern.K(X)
|
K = kern.K(X)
|
||||||
|
|
||||||
f_hat, _ = self.rasm_mode(K, Y, likelihood, Ki_f_init, Y_metadata=Y_metadata)
|
if f_hat is None:
|
||||||
W = -likelihood.d2logpdf_df2(f_hat, Y, Y_metadata=Y_metadata)
|
f_hat, _ = self.rasm_mode(K, Y, likelihood, Ki_f_init, Y_metadata=Y_metadata)
|
||||||
|
|
||||||
|
if W is None:
|
||||||
|
W = -likelihood.d2logpdf_df2(f_hat, Y, Y_metadata=Y_metadata)
|
||||||
|
|
||||||
|
if Ki_W_i is None:
|
||||||
|
_, _, _, Ki_W_i = self._compute_B_statistics(K, W, likelihood.log_concave)
|
||||||
|
|
||||||
logpdf_dfhat = likelihood.dlogpdf_df(f_hat, Y, Y_metadata=Y_metadata)
|
logpdf_dfhat = likelihood.dlogpdf_df(f_hat, Y, Y_metadata=Y_metadata)
|
||||||
|
|
||||||
K_Wi_i, _, _, Ki_W_i = self._compute_B_statistics(K, W, likelihood.log_concave)
|
if W.shape[1] == 1:
|
||||||
|
W = np.diagflat(W)
|
||||||
W = np.diagflat(W)
|
|
||||||
|
|
||||||
#Eq 14, and 16
|
#Eq 14, and 16
|
||||||
var_site = 1./np.diag(W)[:, None]
|
var_site = 1./np.diag(W)[:, None]
|
||||||
|
|
|
||||||
|
|
@ -114,21 +114,24 @@ class Likelihood(Parameterized):
|
||||||
#Otherwise just pass along None's
|
#Otherwise just pass along None's
|
||||||
zipped_values = zip(flat_y_test, flat_mu_star, flat_var_star, [None]*y_test.shape[0])
|
zipped_values = zip(flat_y_test, flat_mu_star, flat_var_star, [None]*y_test.shape[0])
|
||||||
|
|
||||||
def integral_generator(y, m, v, y_m):
|
def integral_generator(yi, mi, vi, yi_m):
|
||||||
"""Generate a function which can be integrated to give p(Y*|Y) = int p(Y*|f*)p(f*|Y) df*"""
|
"""Generate a function which can be integrated
|
||||||
def f(f_star):
|
to give p(Y*|Y) = int p(Y*|f*)p(f*|Y) df*"""
|
||||||
|
def f(fi_star):
|
||||||
#exponent = np.exp(-(1./(2*v))*np.square(m-f_star))
|
#exponent = np.exp(-(1./(2*v))*np.square(m-f_star))
|
||||||
#from GPy.util.misc import safe_exp
|
#from GPy.util.misc import safe_exp
|
||||||
#exponent = safe_exp(exponent)
|
#exponent = safe_exp(exponent)
|
||||||
#return self.pdf(f_star, y, y_m)*exponent
|
#return self.pdf(f_star, y, y_m)*exponent
|
||||||
|
|
||||||
#More stable in the log space
|
#More stable in the log space
|
||||||
return np.exp(self.logpdf(f_star, y, y_m) -(1./(2*v))*np.square(m-f_star))
|
return np.exp(self.logpdf(fi_star, yi, yi_m)
|
||||||
|
- 0.5*np.log(2*np.pi*vi)
|
||||||
|
- 0.5*np.square(mi-fi_star)/vi)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
scaled_p_ystar, accuracy = zip(*[quad(integral_generator(y, m, v, y_m), -np.inf, np.inf) for y, m, v, y_m in zipped_values])
|
p_ystar, _ = zip(*[quad(integral_generator(yi, mi, vi, yi_m), -np.inf, np.inf)
|
||||||
scaled_p_ystar = np.array(scaled_p_ystar).reshape(-1,1)
|
for yi, mi, vi, yi_m in zipped_values])
|
||||||
p_ystar = scaled_p_ystar/np.sqrt(2*np.pi*var_star)
|
p_ystar = np.array(p_ystar).reshape(-1, 1)
|
||||||
return np.log(p_ystar)
|
return np.log(p_ystar)
|
||||||
|
|
||||||
def _moments_match_ep(self,obs,tau,v):
|
def _moments_match_ep(self,obs,tau,v):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue