Added logging for jitter so we know how much has been added and how many tries have been taken

This commit is contained in:
Alan Saul 2015-02-10 11:52:40 +00:00
parent cfdd72fc72
commit f690192384
2 changed files with 10 additions and 13 deletions

View file

@ -27,8 +27,10 @@ class LinalgTests(np.testing.TestCase):
def test_jitchol_failure(self):
try:
""" Expecting an exception to be thrown as we expect it to require
5 rounds of jitter to be added to enforce PDness"""
"""
Expecting an exception to be thrown as we expect it to require
5 rounds of jitter to be added to enforce PDness
"""
jitchol(self.A_corrupt, maxtries=4)
return False
except sp.linalg.LinAlgError:

View file

@ -93,24 +93,19 @@ def jitchol(A, maxtries=5):
if np.any(diagA <= 0.):
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
jitter = diagA.mean() * 1e-6
num_tries = 0
while num_tries < maxtries and np.isfinite(jitter):
num_tries = 1
while num_tries <= maxtries and np.isfinite(jitter):
try:
print jitter
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
logging.warning('Added {} rounds of jitter, jitter of {:.10e}\n'.format(num_tries, jitter))
return L
except:
jitter *= 10
finally:
num_tries += 1
raise linalg.LinAlgError, "not positive definite, even with jitter."
import traceback
try: raise
except:
logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter),
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
import ipdb;ipdb.set_trace()
return L
logging.warning('\n'.join(['Added {} rounds of jitter, jitter of {:.10e}'.format(num_tries-1, jitter),
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
raise linalg.LinAlgError, "not positive definite, even with jitter."
# def dtrtri(L, lower=1):
# """