pdinv now uses dpotri instead of dtrtri and dot

This commit is contained in:
James Hensman 2013-04-10 16:50:34 +01:00
parent 99ca20b77c
commit 6499a76e24

View file

@ -51,6 +51,26 @@ def _mdot_r(a,b):
return np.dot(a,b)
def jitchol(A,maxtries=5):
A = np.asfortranarray(A)
L,info = linalg.lapack.flapack.dpotrf(A,lower=1)
if info ==0:
return L
else:
diagA = np.diag(A)
if np.any(diagA<0.):
raise linalg.LinAlgError, "not pd: negative diagonal elements"
jitter= diagA.mean()*1e-6
for i in range(1,maxtries+1):
print 'Warning: adding jitter of '+str(jitter)
try:
return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True)
except:
jitter *= 10
raise linalg.LinAlgError,"not positive definite, even with jitter."
def jitchol_old(A,maxtries=5):
"""
:param A : An almost pd square matrix
@ -71,7 +91,7 @@ def jitchol(A,maxtries=5):
for i in range(1,maxtries+1):
print 'Warning: adding jitter of '+str(jitter)
try:
return linalg.cholesky(A+np.eye(A.shape[0])*jitter, lower = True)
return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True)
except:
jitter *= 10
@ -93,7 +113,8 @@ def pdinv(A):
L = jitchol(A)
logdet = 2.*np.sum(np.log(np.diag(L)))
Li = chol_inv(L)
Ai = np.dot(Li.T,Li) #TODO: get the flapack routine form multiplying triangular matrices
Ai = linalg.lapack.flapack.dpotri(L)[0]
Ai = np.tril(Ai) + np.tril(Ai,-1).T
return Ai, L, Li, logdet