mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 16:52:39 +02:00
pdinv now uses dpotri instead of dtrtri and dot
This commit is contained in:
parent
99ca20b77c
commit
6499a76e24
1 changed files with 23 additions and 2 deletions
|
|
@ -51,6 +51,26 @@ def _mdot_r(a,b):
|
|||
return np.dot(a,b)
|
||||
|
||||
def jitchol(A,maxtries=5):
|
||||
A = np.asfortranarray(A)
|
||||
L,info = linalg.lapack.flapack.dpotrf(A,lower=1)
|
||||
if info ==0:
|
||||
return L
|
||||
else:
|
||||
diagA = np.diag(A)
|
||||
if np.any(diagA<0.):
|
||||
raise linalg.LinAlgError, "not pd: negative diagonal elements"
|
||||
jitter= diagA.mean()*1e-6
|
||||
for i in range(1,maxtries+1):
|
||||
print 'Warning: adding jitter of '+str(jitter)
|
||||
try:
|
||||
return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True)
|
||||
except:
|
||||
jitter *= 10
|
||||
raise linalg.LinAlgError,"not positive definite, even with jitter."
|
||||
|
||||
|
||||
|
||||
def jitchol_old(A,maxtries=5):
|
||||
"""
|
||||
:param A : An almost pd square matrix
|
||||
|
||||
|
|
@ -71,7 +91,7 @@ def jitchol(A,maxtries=5):
|
|||
for i in range(1,maxtries+1):
|
||||
print 'Warning: adding jitter of '+str(jitter)
|
||||
try:
|
||||
return linalg.cholesky(A+np.eye(A.shape[0])*jitter, lower = True)
|
||||
return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True)
|
||||
except:
|
||||
jitter *= 10
|
||||
|
||||
|
|
@ -93,7 +113,8 @@ def pdinv(A):
|
|||
L = jitchol(A)
|
||||
logdet = 2.*np.sum(np.log(np.diag(L)))
|
||||
Li = chol_inv(L)
|
||||
Ai = np.dot(Li.T,Li) #TODO: get the flapack routine form multiplying triangular matrices
|
||||
Ai = linalg.lapack.flapack.dpotri(L)[0]
|
||||
Ai = np.tril(Ai) + np.tril(Ai,-1).T
|
||||
|
||||
return Ai, L, Li, logdet
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue