mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
Factored out lapack into utils so we can check version and give deprecation warnings
This commit is contained in:
parent
b142b68876
commit
e587618339
6 changed files with 73 additions and 38 deletions
|
|
@ -11,6 +11,12 @@ import types
|
|||
import ctypes
|
||||
from ctypes import byref, c_char, c_int, c_double # TODO
|
||||
# import scipy.lib.lapack
|
||||
import scipy
|
||||
|
||||
if np.all(np.float64((scipy.__version__).split('.')[:2]) >= np.array([0, 10])):
|
||||
import scipy.linalg.lapack as lapack
|
||||
else:
|
||||
import scipy.linalg.lapack.flapack as lapack
|
||||
|
||||
try:
|
||||
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
|
||||
|
|
@ -18,6 +24,36 @@ try:
|
|||
except:
|
||||
_blas_available = False
|
||||
|
||||
def dtrtrs(A, B, lower=0, trans=0, unitdiag=0, overwrite_b=0):
|
||||
"""Wrapper for lapack dtrtrs function
|
||||
|
||||
:param A: Matrix A
|
||||
:param B: Matrix B
|
||||
:param lower: is matrix lower (true) or upper (false)
|
||||
:returns:
|
||||
"""
|
||||
return lapack.dtrtrs(A, B, lower=lower, trans=trans, unitdiag=unitdiag, overwrite_b=overwrite_b)
|
||||
|
||||
def dpotrs(A, B, overwrite_b=0, lower=0):
|
||||
"""Wrapper for lapack dpotrs function
|
||||
|
||||
:param A: Matrix A
|
||||
:param B: Matrix B
|
||||
:param lower: is matrix lower (true) or upper (false)
|
||||
:returns:
|
||||
"""
|
||||
return lapack.dpotrs(A, B, overwrite_b=overwrite_b, lower=lower)
|
||||
|
||||
def dpotri(A, B, overwrite_b=0, lower=0):
|
||||
"""Wrapper for lapack dpotri function
|
||||
|
||||
:param A: Matrix A
|
||||
:param B: Matrix B
|
||||
:param lower: is matrix lower (true) or upper (false)
|
||||
:returns:
|
||||
"""
|
||||
return lapack.dpotri(A, B, overwrite_b=overwrite_b, lower=lower)
|
||||
|
||||
def trace_dot(a, b):
|
||||
"""
|
||||
efficiently compute the trace of the matrix product of a and b
|
||||
|
|
@ -56,7 +92,7 @@ def _mdot_r(a, b):
|
|||
|
||||
def jitchol(A, maxtries=5):
|
||||
A = np.asfortranarray(A)
|
||||
L, info = linalg.lapack.flapack.dpotrf(A, lower=1)
|
||||
L, info = lapack.dpotrf(A, lower=1)
|
||||
if info == 0:
|
||||
return L
|
||||
else:
|
||||
|
|
@ -117,7 +153,7 @@ def pdinv(A, *args):
|
|||
L = jitchol(A, *args)
|
||||
logdet = 2.*np.sum(np.log(np.diag(L)))
|
||||
Li = chol_inv(L)
|
||||
Ai, _ = linalg.lapack.flapack.dpotri(L)
|
||||
Ai, _ = lapack.dpotri(L)
|
||||
# Ai = np.tril(Ai) + np.tril(Ai,-1).T
|
||||
symmetrify(Ai)
|
||||
|
||||
|
|
@ -133,7 +169,7 @@ def chol_inv(L):
|
|||
|
||||
"""
|
||||
|
||||
return linalg.lapack.flapack.dtrtri(L, lower=True)[0]
|
||||
return lapack.dtrtri(L, lower=True)[0]
|
||||
|
||||
|
||||
def multiple_pdinv(A):
|
||||
|
|
@ -150,7 +186,7 @@ def multiple_pdinv(A):
|
|||
N = A.shape[-1]
|
||||
chols = [jitchol(A[:, :, i]) for i in range(N)]
|
||||
halflogdets = [np.sum(np.log(np.diag(L[0]))) for L in chols]
|
||||
invs = [linalg.lapack.flapack.dpotri(L[0], True)[0] for L in chols]
|
||||
invs = [lapack.dpotri(L[0], True)[0] for L in chols]
|
||||
invs = [np.triu(I) + np.triu(I, 1).T for I in invs]
|
||||
return np.dstack(invs), np.array(halflogdets)
|
||||
|
||||
|
|
@ -351,9 +387,9 @@ def cholupdate(L, x):
|
|||
def backsub_both_sides(L, X, transpose='left'):
|
||||
""" Return L^-T * X * L^-1, assumuing X is symmetrical and L is lower cholesky"""
|
||||
if transpose == 'left':
|
||||
tmp, _ = linalg.lapack.flapack.dtrtrs(L, np.asfortranarray(X), lower=1, trans=1)
|
||||
return linalg.lapack.flapack.dtrtrs(L, np.asfortranarray(tmp.T), lower=1, trans=1)[0].T
|
||||
tmp, _ = lapack.dtrtrs(L, np.asfortranarray(X), lower=1, trans=1)
|
||||
return lapack.dtrtrs(L, np.asfortranarray(tmp.T), lower=1, trans=1)[0].T
|
||||
else:
|
||||
tmp, _ = linalg.lapack.flapack.dtrtrs(L, np.asfortranarray(X), lower=1, trans=0)
|
||||
return linalg.lapack.flapack.dtrtrs(L, np.asfortranarray(tmp.T), lower=1, trans=0)[0].T
|
||||
tmp, _ = lapack.dtrtrs(L, np.asfortranarray(X), lower=1, trans=0)
|
||||
return lapack.dtrtrs(L, np.asfortranarray(tmp.T), lower=1, trans=0)[0].T
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue