From 8b4274339ad034aeede9b926ac47bad89ae2f397 Mon Sep 17 00:00:00 2001 From: Daniel Beck Date: Mon, 9 Feb 2015 09:28:53 +1100 Subject: [PATCH] added decorator that changes numpy invalid op warning to exception --- GPy/kern/_src/prod.py | 20 +++++++++++++++++--- GPy/testing/kernel_tests.py | 7 ++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/GPy/kern/_src/prod.py b/GPy/kern/_src/prod.py index e3776838..4f9f5ea6 100644 --- a/GPy/kern/_src/prod.py +++ b/GPy/kern/_src/prod.py @@ -7,6 +7,19 @@ from ...util.caching import Cache_this import itertools import operator + +def numpy_invalid_op_as_exception(func): + """ + A decorator that allows catching numpy invalid operations + as exceptions (the default behaviour is raising warnings). + """ + def func_wrapper(*args, **kwargs): + np.seterr(invalid='raise') + func(*args, **kwargs) + np.seterr(invalid='warn') + return func_wrapper + + class Prod(CombinationKernel): """ Computes the product of 2 kernels @@ -42,15 +55,14 @@ class Prod(CombinationKernel): which_parts = self.parts return reduce(np.multiply, (p.Kdiag(X) for p in which_parts)) + @numpy_invalid_op_as_exception def update_gradients_full(self, dL_dK, X, X2=None): - np.seterr(invalid='raise') k = self.K(X,X2)*dL_dK try: for p in self.parts: p.update_gradients_full(k/p.K(X,X2),X,X2) except FloatingPointError: - np.seterr(invalid='warn') - print "Gradient warning: falling back to slow version due to zero-valued kernel" + #print "WARNING: gradient calculation falling back to slow version due to zero-valued kernel" for combination in itertools.combinations(self.parts, len(self.parts) - 1): prod = reduce(operator.mul, [p.K(X, X2) for p in combination]) to_update = list(set(self.parts) - set(combination))[0] @@ -75,3 +87,5 @@ class Prod(CombinationKernel): for p in self.parts: target += p.gradients_X_diag(k/p.Kdiag(X),X) return target + + diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py index 387047b6..ac6d7ab4 100644 --- a/GPy/testing/kernel_tests.py +++ b/GPy/testing/kernel_tests.py @@ -408,10 +408,11 @@ class KernelTestsProductWithZeroValues(unittest.TestCase): lin = GPy.kern.Linear(2) bias = GPy.kern.Bias(2) k = lin * bias - #k = lin m = GPy.models.GPRegression(X, Y, kernel=k) - #m['mul.bias.variance'].constrain_fixed(0) - m.optimize(messages=False) + try: + m.optimize() + except np.linalg.LinAlgError: + self.fail("Zero-valued kernel raised exception!")