diff --git a/GPy/testing/psi_stat_tests.py b/GPy/testing/psi_stat_tests.py index 93f9867c..22737ca1 100644 --- a/GPy/testing/psi_stat_tests.py +++ b/GPy/testing/psi_stat_tests.py @@ -39,10 +39,6 @@ class PsiStatModel(model): self.Z = x[start: end].reshape(self.M, self.Q) self.kern._set_params(x[end:]) def log_likelihood(self): -# if '2' in self.which: -# norm = self.N ** 2 -# else: # '0', '1' in self.which: -# norm = self.N return self.kern.__getattribute__(self.which)(self.Z, self.X, self.X_variance).sum() def _log_likelihood_gradients(self): psi_ = self.kern.__getattribute__(self.which)(self.Z, self.X, self.X_variance) @@ -64,23 +60,27 @@ class Test(unittest.TestCase): Z = numpy.random.permutation(X)[:M] Y = X.dot(numpy.random.randn(Q, D)) + kernels = [GPy.kern.linear(Q), GPy.kern.rbf(Q), GPy.kern.bias(Q), + GPy.kern.linear(Q) + GPy.kern.bias(Q), + GPy.kern.rbf(Q) + GPy.kern.bias(Q)] + def testPsi0(self): - kernel = GPy.kern.linear(Q) - m = PsiStatModel('psi0', X=X, X_variance=X_var, Z=Z, - M=M, kernel=kernel, mu_or_S=0, dL=numpy.ones((1))) - assert m.checkgrad(), "linear x psi0" + for k in self.kernels: + m = PsiStatModel('psi0', X=self.X, X_variance=self.X_var, Z=self.Z, + M=self.M, kernel=k) + assert m.checkgrad(), "{} x psi0".format("+".join(map(lambda x: x.name, k.parts))) def testPsi1(self): - kernel = GPy.kern.linear(Q) - m = PsiStatModel('psi1', X=X, X_variance=X_var, Z=Z, - M=M, kernel=kernel, mu_or_S=0, dL=numpy.ones((1, 1))) - assert(m.checkgrad()) + for k in self.kernels: + m = PsiStatModel('psi0', X=self.X, X_variance=self.X_var, Z=self.Z, + M=self.M, kernel=k) + assert m.checkgrad(), "{} x psi1".format("+".join(map(lambda x: x.name, k.parts))) def testPsi2(self): - kernel = GPy.kern.linear(Q) - m = PsiStatModel('psi2', X=X, X_variance=X_var, Z=Z, - M=M, kernel=kernel, mu_or_S=0, dL=numpy.ones((1, 1, 1))) - assert(m.checkgrad()) + for k in self.kernels: + m = PsiStatModel('psi0', X=self.X, X_variance=self.X_var, Z=self.Z, + M=self.M, kernel=k) + assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k.parts))) if __name__ == "__main__":