diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py index 59f598f9..d82bb50f 100644 --- a/GPy/util/linalg.py +++ b/GPy/util/linalg.py @@ -51,6 +51,26 @@ def _mdot_r(a,b): return np.dot(a,b) def jitchol(A,maxtries=5): + A = np.asfortranarray(A) + L,info = linalg.lapack.flapack.dpotrf(A,lower=1) + if info ==0: + return L + else: + diagA = np.diag(A) + if np.any(diagA<0.): + raise linalg.LinAlgError, "not pd: negative diagonal elements" + jitter= diagA.mean()*1e-6 + for i in range(1,maxtries+1): + print 'Warning: adding jitter of '+str(jitter) + try: + return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True) + except: + jitter *= 10 + raise linalg.LinAlgError,"not positive definite, even with jitter." + + + +def jitchol_old(A,maxtries=5): """ :param A : An almost pd square matrix @@ -71,7 +91,7 @@ def jitchol(A,maxtries=5): for i in range(1,maxtries+1): print 'Warning: adding jitter of '+str(jitter) try: - return linalg.cholesky(A+np.eye(A.shape[0])*jitter, lower = True) + return linalg.cholesky(A+np.eye(A.shape[0]).T*jitter, lower = True) except: jitter *= 10 @@ -93,7 +113,8 @@ def pdinv(A): L = jitchol(A) logdet = 2.*np.sum(np.log(np.diag(L))) Li = chol_inv(L) - Ai = np.dot(Li.T,Li) #TODO: get the flapack routine form multiplying triangular matrices + Ai = linalg.lapack.flapack.dpotri(L)[0] + Ai = np.tril(Ai) + np.tril(Ai,-1).T return Ai, L, Li, logdet