mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
Bug in linalg jitchol!!!
This commit is contained in:
parent
7fbbdafdbf
commit
29d153e185
2 changed files with 41 additions and 6 deletions
35
GPy/testing/linalg_test.py
Normal file
35
GPy/testing/linalg_test.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
import numpy as np
|
||||||
|
import scipy as sp
|
||||||
|
from ..util.linalg import jitchol
|
||||||
|
|
||||||
|
class LinalgTests(np.testing.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
#Create PD matrix
|
||||||
|
A = np.random.randn(20,100)
|
||||||
|
self.A = A.dot(A.T)
|
||||||
|
#compute Eigdecomp
|
||||||
|
vals, vectors = np.linalg.eig(self.A)
|
||||||
|
#Set smallest eigenval to be negative with 5 rounds worth of jitter
|
||||||
|
vals[vals.argmin()] = 0
|
||||||
|
default_jitter = 1e-6*np.mean(vals)
|
||||||
|
vals[vals.argmin()] = -default_jitter*(10**3.5)
|
||||||
|
self.A_corrupt = (vectors * vals).dot(vectors.T)
|
||||||
|
|
||||||
|
def test_jitchol_success(self):
|
||||||
|
"""
|
||||||
|
Expect 5 rounds of jitter to be added and for the recovered matrix to be
|
||||||
|
identical to the corrupted matrix apart from the jitter added to the diagonal
|
||||||
|
"""
|
||||||
|
L = jitchol(self.A_corrupt, maxtries=5)
|
||||||
|
A_new = L.dot(L.T)
|
||||||
|
diff = A_new - self.A_corrupt
|
||||||
|
np.testing.assert_allclose(diff, np.eye(A_new.shape[0])*np.diag(diff).mean(), atol=1e-13)
|
||||||
|
|
||||||
|
def test_jitchol_failure(self):
|
||||||
|
try:
|
||||||
|
""" Expecting an exception to be thrown as we expect it to require
|
||||||
|
5 rounds of jitter to be added to enforce PDness"""
|
||||||
|
jitchol(self.A_corrupt, maxtries=4)
|
||||||
|
return False
|
||||||
|
except sp.linalg.LinAlgError:
|
||||||
|
return True
|
||||||
|
|
@ -82,6 +82,7 @@ def force_F_ordered(A):
|
||||||
|
|
||||||
# return jitchol(A+np.eye(A.shape[0])*jitter, maxtries-1)
|
# return jitchol(A+np.eye(A.shape[0])*jitter, maxtries-1)
|
||||||
|
|
||||||
|
|
||||||
def jitchol(A, maxtries=5):
|
def jitchol(A, maxtries=5):
|
||||||
A = np.ascontiguousarray(A)
|
A = np.ascontiguousarray(A)
|
||||||
L, info = lapack.dpotrf(A, lower=1)
|
L, info = lapack.dpotrf(A, lower=1)
|
||||||
|
|
@ -92,13 +93,16 @@ def jitchol(A, maxtries=5):
|
||||||
if np.any(diagA <= 0.):
|
if np.any(diagA <= 0.):
|
||||||
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
|
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
|
||||||
jitter = diagA.mean() * 1e-6
|
jitter = diagA.mean() * 1e-6
|
||||||
while maxtries > 0 and np.isfinite(jitter):
|
num_tries = 0
|
||||||
|
while num_tries < maxtries and np.isfinite(jitter):
|
||||||
try:
|
try:
|
||||||
|
print jitter
|
||||||
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
||||||
|
return L
|
||||||
except:
|
except:
|
||||||
jitter *= 10
|
jitter *= 10
|
||||||
finally:
|
finally:
|
||||||
maxtries -= 1
|
num_tries += 1
|
||||||
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
||||||
import traceback
|
import traceback
|
||||||
try: raise
|
try: raise
|
||||||
|
|
@ -108,10 +112,6 @@ def jitchol(A, maxtries=5):
|
||||||
import ipdb;ipdb.set_trace()
|
import ipdb;ipdb.set_trace()
|
||||||
return L
|
return L
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def dtrtri(L, lower=1):
|
# def dtrtri(L, lower=1):
|
||||||
# """
|
# """
|
||||||
# Wrapper for lapack dtrtri function
|
# Wrapper for lapack dtrtri function
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue