mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-14 14:32:37 +02:00
Added logging for jitter so we know how much has been added and how many tries have been taken
This commit is contained in:
parent
cfdd72fc72
commit
f690192384
2 changed files with 10 additions and 13 deletions
|
|
@ -27,8 +27,10 @@ class LinalgTests(np.testing.TestCase):
|
||||||
|
|
||||||
def test_jitchol_failure(self):
|
def test_jitchol_failure(self):
|
||||||
try:
|
try:
|
||||||
""" Expecting an exception to be thrown as we expect it to require
|
"""
|
||||||
5 rounds of jitter to be added to enforce PDness"""
|
Expecting an exception to be thrown as we expect it to require
|
||||||
|
5 rounds of jitter to be added to enforce PDness
|
||||||
|
"""
|
||||||
jitchol(self.A_corrupt, maxtries=4)
|
jitchol(self.A_corrupt, maxtries=4)
|
||||||
return False
|
return False
|
||||||
except sp.linalg.LinAlgError:
|
except sp.linalg.LinAlgError:
|
||||||
|
|
|
||||||
|
|
@ -93,24 +93,19 @@ def jitchol(A, maxtries=5):
|
||||||
if np.any(diagA <= 0.):
|
if np.any(diagA <= 0.):
|
||||||
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
|
raise linalg.LinAlgError, "not pd: non-positive diagonal elements"
|
||||||
jitter = diagA.mean() * 1e-6
|
jitter = diagA.mean() * 1e-6
|
||||||
num_tries = 0
|
num_tries = 1
|
||||||
while num_tries < maxtries and np.isfinite(jitter):
|
while num_tries <= maxtries and np.isfinite(jitter):
|
||||||
try:
|
try:
|
||||||
print jitter
|
|
||||||
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
|
||||||
|
logging.warning('Added {} rounds of jitter, jitter of {:.10e}\n'.format(num_tries, jitter))
|
||||||
return L
|
return L
|
||||||
except:
|
except:
|
||||||
jitter *= 10
|
jitter *= 10
|
||||||
finally:
|
|
||||||
num_tries += 1
|
num_tries += 1
|
||||||
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
|
||||||
import traceback
|
import traceback
|
||||||
try: raise
|
logging.warning('\n'.join(['Added {} rounds of jitter, jitter of {:.10e}'.format(num_tries-1, jitter),
|
||||||
except:
|
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
|
||||||
logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter),
|
raise linalg.LinAlgError, "not positive definite, even with jitter."
|
||||||
' in '+traceback.format_list(traceback.extract_stack(limit=2)[-2:-1])[0][2:]]))
|
|
||||||
import ipdb;ipdb.set_trace()
|
|
||||||
return L
|
|
||||||
|
|
||||||
# def dtrtri(L, lower=1):
|
# def dtrtri(L, lower=1):
|
||||||
# """
|
# """
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue