From 8e69db51a290eb0335dd433f639ebbd557b44a71 Mon Sep 17 00:00:00 2001 From: Eero Siivola Date: Sun, 24 Jun 2018 12:41:36 +0300 Subject: [PATCH] Modified likelihoods test to better test multioutput likelihood --- GPy/testing/likelihood_tests.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/GPy/testing/likelihood_tests.py b/GPy/testing/likelihood_tests.py index 2c1a6e24..c665d6ab 100644 --- a/GPy/testing/likelihood_tests.py +++ b/GPy/testing/likelihood_tests.py @@ -129,6 +129,15 @@ class TestNoiseModels(object): self.Y_metadata = dict() self.Y_metadata['censored'] = censored self.Y_metadata['output_index'] = np.zeros((self.N,1), dtype=int) + self.Y_metadata2 = dict() + self.Y_metadata2['censored'] = censored + inds = np.zeros((self.N,1), dtype=int) + inds[5:10] = 1 + inds[10:] = 2 + self.Y_metadata2['output_index'] = inds + self.combY = self.Y + self.combY[10:] = np.where(self.binary_Y[10:] >0, self.binary_Y[10:], 0) + print(self.combY) #Make a bigger step as lower bound can be quite curved self.step = 1e-4 @@ -294,11 +303,11 @@ class TestNoiseModels(object): "laplace": True }, "multioutput_default": { - "model": GPy.likelihoods.MultioutputLikelihood([GPy.likelihoods.Bernoulli()]), + "model": GPy.likelihoods.MultioutputLikelihood([GPy.likelihoods.Gaussian(), GPy.likelihoods.Poisson(), GPy.likelihoods.Bernoulli()]), "link_f_constraints": [partial(self.constrain_bounded, lower=0, upper=1)], "laplace": True, - "Y": self.binary_Y, - "Y_metadata": self.Y_metadata, + "Y": self.combY, + "Y_metadata": self.Y_metadata2, "ep": True, "variational_expectations": True, } @@ -627,7 +636,7 @@ class TestNoiseModels(object): # Y = Y/Y.max() white_var = 1e-4 kernel = GPy.kern.RBF(X.shape[1]) + GPy.kern.White(X.shape[1]) - ep_inf = GPy.inference.latent_function_inference.EP() + ep_inf = GPy.inference.latent_function_inference.EP(always_reset=True) m = GPy.core.GP(X.copy(), Y.copy(), kernel=kernel, likelihood=model, Y_metadata=Y_metadata, inference_method=ep_inf) m['.*white'].constrain_fixed(white_var)