From 952851de88c2a0054502c2fa0b98109ee867ecde Mon Sep 17 00:00:00 2001 From: Alan Saul Date: Mon, 9 Feb 2015 19:35:46 +0000 Subject: [PATCH] Bug in linalg jitchol!!! --- GPy/testing/linalg_test.py | 35 +++++++++++++++++++++++++++++++++++ GPy/util/linalg.py | 12 ++++++------ 2 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 GPy/testing/linalg_test.py diff --git a/GPy/testing/linalg_test.py b/GPy/testing/linalg_test.py new file mode 100644 index 00000000..b734f6af --- /dev/null +++ b/GPy/testing/linalg_test.py @@ -0,0 +1,35 @@ +import numpy as np +import scipy as sp +from ..util.linalg import jitchol + +class LinalgTests(np.testing.TestCase): + def setUp(self): + #Create PD matrix + A = np.random.randn(20,100) + self.A = A.dot(A.T) + #compute Eigdecomp + vals, vectors = np.linalg.eig(self.A) + #Set smallest eigenval to be negative with 5 rounds worth of jitter + vals[vals.argmin()] = 0 + default_jitter = 1e-6*np.mean(vals) + vals[vals.argmin()] = -default_jitter*(10**3.5) + self.A_corrupt = (vectors * vals).dot(vectors.T) + + def test_jitchol_success(self): + """ + Expect 5 rounds of jitter to be added and for the recovered matrix to be + identical to the corrupted matrix apart from the jitter added to the diagonal + """ + L = jitchol(self.A_corrupt, maxtries=5) + A_new = L.dot(L.T) + diff = A_new - self.A_corrupt + np.testing.assert_allclose(diff, np.eye(A_new.shape[0])*np.diag(diff).mean(), atol=1e-13) + + def test_jitchol_failure(self): + try: + """ Expecting an exception to be thrown as we expect it to require + 5 rounds of jitter to be added to enforce PDness""" + jitchol(self.A_corrupt, maxtries=4) + return False + except sp.linalg.LinAlgError: + return True diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py index dffd438a..2c02357c 100644 --- a/GPy/util/linalg.py +++ b/GPy/util/linalg.py @@ -82,6 +82,7 @@ def force_F_ordered(A): # return jitchol(A+np.eye(A.shape[0])*jitter, maxtries-1) + def jitchol(A, maxtries=5): A = np.ascontiguousarray(A) L, info = lapack.dpotrf(A, lower=1) @@ -92,13 +93,16 @@ def jitchol(A, maxtries=5): if np.any(diagA <= 0.): raise linalg.LinAlgError, "not pd: non-positive diagonal elements" jitter = diagA.mean() * 1e-6 - while maxtries > 0 and np.isfinite(jitter): + num_tries = 0 + while num_tries < maxtries and np.isfinite(jitter): try: + print jitter L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True) + return L except: jitter *= 10 finally: - maxtries -= 1 + num_tries += 1 raise linalg.LinAlgError, "not positive definite, even with jitter." import traceback try: raise @@ -108,10 +112,6 @@ def jitchol(A, maxtries=5): import ipdb;ipdb.set_trace() return L - - - - # def dtrtri(L, lower=1): # """ # Wrapper for lapack dtrtri function