mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-11 21:12:38 +02:00
FIX: now Scipy 0.16 is required, removing fixes for older versions. Accessing blas through the scipy interface
This commit is contained in:
parent
f221a3b1fa
commit
ba2ea3eb73
1 changed files with 12 additions and 53 deletions
|
|
@ -7,48 +7,16 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import linalg
|
from scipy import linalg
|
||||||
import types
|
from scipy.linalg import lapack, blas
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
from ctypes import byref, c_char, c_int, c_double # TODO
|
from ctypes import byref, c_char, c_int, c_double # TODO
|
||||||
import scipy
|
|
||||||
import warnings
|
|
||||||
import os
|
|
||||||
from .config import config
|
from .config import config
|
||||||
import logging
|
import logging
|
||||||
from . import linalg_cython
|
from . import linalg_cython
|
||||||
|
|
||||||
|
|
||||||
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
|
|
||||||
_fix_dpotri_scipy_bug = True
|
|
||||||
if np.all(_scipyversion >= np.array([0, 14])):
|
|
||||||
from scipy.linalg import lapack
|
|
||||||
_fix_dpotri_scipy_bug = False
|
|
||||||
elif np.all(_scipyversion >= np.array([0, 12])):
|
|
||||||
#import scipy.linalg.lapack.clapack as lapack
|
|
||||||
from scipy.linalg import lapack
|
|
||||||
else:
|
|
||||||
from scipy.linalg.lapack import flapack as lapack
|
|
||||||
|
|
||||||
if config.getboolean('anaconda', 'installed') and config.getboolean('anaconda', 'MKL'):
|
|
||||||
try:
|
|
||||||
anaconda_path = str(config.get('anaconda', 'location'))
|
|
||||||
mkl_rt = ctypes.cdll.LoadLibrary(os.path.join(anaconda_path, 'DLLs', 'mkl_rt.dll'))
|
|
||||||
dsyrk = mkl_rt.dsyrk
|
|
||||||
dsyr = mkl_rt.dsyr
|
|
||||||
_blas_available = True
|
|
||||||
print('anaconda installed and mkl is loaded')
|
|
||||||
except:
|
|
||||||
_blas_available = False
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
|
|
||||||
dsyrk = _blaslib.dsyrk_
|
|
||||||
dsyr = _blaslib.dsyr_
|
|
||||||
_blas_available = True
|
|
||||||
except AttributeError as e:
|
|
||||||
_blas_available = False
|
|
||||||
warnings.warn("warning: caught this exception:" + str(e))
|
|
||||||
|
|
||||||
def force_F_ordered_symmetric(A):
|
def force_F_ordered_symmetric(A):
|
||||||
"""
|
"""
|
||||||
return a F ordered version of A, assuming A is symmetric
|
return a F ordered version of A, assuming A is symmetric
|
||||||
|
|
@ -169,9 +137,6 @@ def dpotri(A, lower=1):
|
||||||
:returns: A inverse
|
:returns: A inverse
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _fix_dpotri_scipy_bug:
|
|
||||||
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
|
|
||||||
lower = 0
|
|
||||||
|
|
||||||
A = force_F_ordered(A)
|
A = force_F_ordered(A)
|
||||||
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
||||||
|
|
@ -300,8 +265,8 @@ def pca(Y, input_dim):
|
||||||
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
|
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
|
||||||
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
|
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
|
||||||
v = X.std(axis=0)
|
v = X.std(axis=0)
|
||||||
X /= v;
|
X /= v
|
||||||
W *= v;
|
W *= v
|
||||||
return X, W.T
|
return X, W.T
|
||||||
|
|
||||||
def ppca(Y, Q, iterations=100):
|
def ppca(Y, Q, iterations=100):
|
||||||
|
|
@ -362,19 +327,15 @@ def tdot_blas(mat, out=None):
|
||||||
BETA = c_double(0.0)
|
BETA = c_double(0.0)
|
||||||
C = out.ctypes.data_as(ctypes.c_void_p)
|
C = out.ctypes.data_as(ctypes.c_void_p)
|
||||||
LDC = c_int(np.max(out.strides) // 8)
|
LDC = c_int(np.max(out.strides) // 8)
|
||||||
dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
|
blas.dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
|
||||||
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
|
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
|
||||||
|
|
||||||
symmetrify(out, upper=True)
|
symmetrify(out, upper=True)
|
||||||
|
|
||||||
|
|
||||||
return np.ascontiguousarray(out)
|
return np.ascontiguousarray(out)
|
||||||
|
|
||||||
def tdot(*args, **kwargs):
|
def tdot(*args, **kwargs):
|
||||||
if _blas_available:
|
return tdot_blas(*args, **kwargs)
|
||||||
return tdot_blas(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return tdot_numpy(*args, **kwargs)
|
|
||||||
|
|
||||||
def DSYR_blas(A, x, alpha=1.):
|
def DSYR_blas(A, x, alpha=1.):
|
||||||
"""
|
"""
|
||||||
|
|
@ -393,8 +354,8 @@ def DSYR_blas(A, x, alpha=1.):
|
||||||
A_ = A.ctypes.data_as(ctypes.c_void_p)
|
A_ = A.ctypes.data_as(ctypes.c_void_p)
|
||||||
x_ = x.ctypes.data_as(ctypes.c_void_p)
|
x_ = x.ctypes.data_as(ctypes.c_void_p)
|
||||||
INCX = c_int(1)
|
INCX = c_int(1)
|
||||||
dsyr(byref(UPLO), byref(N), byref(ALPHA),
|
blas.dsyr(byref(UPLO), byref(N), byref(ALPHA),
|
||||||
x_, byref(INCX), A_, byref(LDA))
|
x_, byref(INCX), A_, byref(LDA))
|
||||||
symmetrify(A, upper=True)
|
symmetrify(A, upper=True)
|
||||||
|
|
||||||
def DSYR_numpy(A, x, alpha=1.):
|
def DSYR_numpy(A, x, alpha=1.):
|
||||||
|
|
@ -411,10 +372,8 @@ def DSYR_numpy(A, x, alpha=1.):
|
||||||
|
|
||||||
|
|
||||||
def DSYR(*args, **kwargs):
|
def DSYR(*args, **kwargs):
|
||||||
if _blas_available:
|
return DSYR_blas(*args, **kwargs)
|
||||||
return DSYR_blas(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return DSYR_numpy(*args, **kwargs)
|
|
||||||
|
|
||||||
def symmetrify(A, upper=False):
|
def symmetrify(A, upper=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue