diff --git a/GPy/testing/linalg_test.py b/GPy/testing/linalg_test.py index b734f6af..8e103795 100644 --- a/GPy/testing/linalg_test.py +++ b/GPy/testing/linalg_test.py @@ -27,8 +27,10 @@ class LinalgTests(np.testing.TestCase): def test_jitchol_failure(self): try: - """ Expecting an exception to be thrown as we expect it to require - 5 rounds of jitter to be added to enforce PDness""" + """ + Expecting an exception to be thrown as we expect it to require + 5 rounds of jitter to be added to enforce PDness + """ jitchol(self.A_corrupt, maxtries=4) return False except sp.linalg.LinAlgError: diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py index 2c02357c..b148f2f4 100644 --- a/GPy/util/linalg.py +++ b/GPy/util/linalg.py @@ -93,24 +93,19 @@ def jitchol(A, maxtries=5): if np.any(diagA <= 0.): raise linalg.LinAlgError, "not pd: non-positive diagonal elements" jitter = diagA.mean() * 1e-6 - num_tries = 0 - while num_tries < maxtries and np.isfinite(jitter): + num_tries = 1 + while num_tries <= maxtries and np.isfinite(jitter): try: - print jitter L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True) + logging.warning('Added {} rounds of jitter, jitter of {:.10e}\n'.format(num_tries, jitter)) return L except: jitter *= 10 - finally: num_tries += 1 - raise linalg.LinAlgError, "not positive definite, even with jitter." import traceback - try: raise - except: - logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter), - ' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]])) - import ipdb;ipdb.set_trace() - return L + logging.warning('\n'.join(['Added {} rounds of jitter, jitter of {:.10e}'.format(num_tries-1, jitter), + ' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]])) + raise linalg.LinAlgError, "not positive definite, even with jitter." # def dtrtri(L, lower=1): # """