Added logging for jitter so we know how much has been added and how many tries have been taken

This commit is contained in:
Alan Saul 2015-02-10 11:52:40 +00:00
parent cfdd72fc72
commit f690192384
2 changed files with 10 additions and 13 deletions

View file

@ -27,8 +27,10 @@ class LinalgTests(np.testing.TestCase):
def test_jitchol_failure(self): def test_jitchol_failure(self):
try: try:
""" Expecting an exception to be thrown as we expect it to require """
5 rounds of jitter to be added to enforce PDness""" Expecting an exception to be thrown as we expect it to require
5 rounds of jitter to be added to enforce PDness
"""
jitchol(self.A_corrupt, maxtries=4) jitchol(self.A_corrupt, maxtries=4)
return False return False
except sp.linalg.LinAlgError: except sp.linalg.LinAlgError:

View file

@ -93,24 +93,19 @@ def jitchol(A, maxtries=5):
if np.any(diagA <= 0.): if np.any(diagA <= 0.):
raise linalg.LinAlgError, "not pd: non-positive diagonal elements" raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
jitter = diagA.mean() * 1e-6 jitter = diagA.mean() * 1e-6
num_tries = 0 num_tries = 1
while num_tries < maxtries and np.isfinite(jitter): while num_tries <= maxtries and np.isfinite(jitter):
try: try:
print jitter
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True) L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
logging.warning('Added {} rounds of jitter, jitter of {:.10e}\n'.format(num_tries, jitter))
return L return L
except: except:
jitter *= 10 jitter *= 10
finally:
num_tries += 1 num_tries += 1
raise linalg.LinAlgError, "not positive definite, even with jitter."
import traceback import traceback
try: raise logging.warning('\n'.join(['Added {} rounds of jitter, jitter of {:.10e}'.format(num_tries-1, jitter),
except: ' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter), raise linalg.LinAlgError, "not positive definite, even with jitter."
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
import ipdb;ipdb.set_trace()
return L
# def dtrtri(L, lower=1): # def dtrtri(L, lower=1):
# """ # """