Bug in linalg jitchol!!!

This commit is contained in:
Alan Saul 2015-02-09 19:35:46 +00:00
parent 7fbbdafdbf
commit 29d153e185
2 changed files with 41 additions and 6 deletions

View file

@ -0,0 +1,35 @@
import numpy as np
import scipy as sp
from ..util.linalg import jitchol
class LinalgTests(np.testing.TestCase):
def setUp(self):
#Create PD matrix
A = np.random.randn(20,100)
self.A = A.dot(A.T)
#compute Eigdecomp
vals, vectors = np.linalg.eig(self.A)
#Set smallest eigenval to be negative with 5 rounds worth of jitter
vals[vals.argmin()] = 0
default_jitter = 1e-6*np.mean(vals)
vals[vals.argmin()] = -default_jitter*(10**3.5)
self.A_corrupt = (vectors * vals).dot(vectors.T)
def test_jitchol_success(self):
"""
Expect 5 rounds of jitter to be added and for the recovered matrix to be
identical to the corrupted matrix apart from the jitter added to the diagonal
"""
L = jitchol(self.A_corrupt, maxtries=5)
A_new = L.dot(L.T)
diff = A_new - self.A_corrupt
np.testing.assert_allclose(diff, np.eye(A_new.shape[0])*np.diag(diff).mean(), atol=1e-13)
def test_jitchol_failure(self):
try:
""" Expecting an exception to be thrown as we expect it to require
5 rounds of jitter to be added to enforce PDness"""
jitchol(self.A_corrupt, maxtries=4)
return False
except sp.linalg.LinAlgError:
return True

View file

@ -82,6 +82,7 @@ def force_F_ordered(A):
# return jitchol(A+np.eye(A.shape[0])*jitter, maxtries-1) # return jitchol(A+np.eye(A.shape[0])*jitter, maxtries-1)
def jitchol(A, maxtries=5): def jitchol(A, maxtries=5):
A = np.ascontiguousarray(A) A = np.ascontiguousarray(A)
L, info = lapack.dpotrf(A, lower=1) L, info = lapack.dpotrf(A, lower=1)
@ -92,13 +93,16 @@ def jitchol(A, maxtries=5):
if np.any(diagA <= 0.): if np.any(diagA <= 0.):
raise linalg.LinAlgError, "not pd: non-positive diagonal elements" raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
jitter = diagA.mean() * 1e-6 jitter = diagA.mean() * 1e-6
while maxtries > 0 and np.isfinite(jitter): num_tries = 0
while num_tries < maxtries and np.isfinite(jitter):
try: try:
print jitter
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True) L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
return L
except: except:
jitter *= 10 jitter *= 10
finally: finally:
maxtries -= 1 num_tries += 1
raise linalg.LinAlgError, "not positive definite, even with jitter." raise linalg.LinAlgError, "not positive definite, even with jitter."
import traceback import traceback
try: raise try: raise
@ -108,10 +112,6 @@ def jitchol(A, maxtries=5):
import ipdb;ipdb.set_trace() import ipdb;ipdb.set_trace()
return L return L
# def dtrtri(L, lower=1): # def dtrtri(L, lower=1):
# """ # """
# Wrapper for lapack dtrtri function # Wrapper for lapack dtrtri function