diff --git a/GPy/kern/_src/prod.py b/GPy/kern/_src/prod.py index dd9a5fe4..e3776838 100644 --- a/GPy/kern/_src/prod.py +++ b/GPy/kern/_src/prod.py @@ -5,6 +5,7 @@ import numpy as np from kern import CombinationKernel from ...util.caching import Cache_this import itertools +import operator class Prod(CombinationKernel): """ @@ -42,9 +43,19 @@ class Prod(CombinationKernel): return reduce(np.multiply, (p.Kdiag(X) for p in which_parts)) def update_gradients_full(self, dL_dK, X, X2=None): + np.seterr(invalid='raise') k = self.K(X,X2)*dL_dK - for p in self.parts: - p.update_gradients_full(k/p.K(X,X2),X,X2) + try: + for p in self.parts: + p.update_gradients_full(k/p.K(X,X2),X,X2) + except FloatingPointError: + np.seterr(invalid='warn') + print "Gradient warning: falling back to slow version due to zero-valued kernel" + for combination in itertools.combinations(self.parts, len(self.parts) - 1): + prod = reduce(operator.mul, [p.K(X, X2) for p in combination]) + to_update = list(set(self.parts) - set(combination))[0] + to_update.update_gradients_full(dL_dK * prod, X, X2) + def update_gradients_diag(self, dL_dKdiag, X): k = self.Kdiag(X)*dL_dKdiag diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py index c1bb9265..387047b6 100644 --- a/GPy/testing/kernel_tests.py +++ b/GPy/testing/kernel_tests.py @@ -400,12 +400,27 @@ class Coregionalize_weave_test(unittest.TestCase): #reset the weave state for any other tests GPy.util.config.config.set('weave', 'working', 'False') +class KernelTestsProductWithZeroValues(unittest.TestCase): + + def test_zero_valued_kernel(self): + X = np.array([[0,1],[1,0]]) + Y = np.array([[1],[10]]) + lin = GPy.kern.Linear(2) + bias = GPy.kern.Bias(2) + k = lin * bias + #k = lin + m = GPy.models.GPRegression(X, Y, kernel=k) + #m['mul.bias.variance'].constrain_fixed(0) + m.optimize(messages=False) if __name__ == "__main__": print "Running unit tests, please be (very) patient..." unittest.main() + #suite = unittest.TestLoader().loadTestsFromTestCase(KernelTestsProductWithZeroValues) + #unittest.TextTestRunner().run(suite) + # np.random.seed(0) # N0 = 3 # N1 = 9