mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-11 21:12:38 +02:00
FIX: now Scipy 0.16 is required, removing fixes for older versions. Accessing blas through the scipy interface
This commit is contained in:
parent
f221a3b1fa
commit
ba2ea3eb73
1 changed files with 12 additions and 53 deletions
|
|
@ -7,48 +7,16 @@
|
|||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
import types
|
||||
from scipy.linalg import lapack, blas
|
||||
|
||||
import ctypes
|
||||
from ctypes import byref, c_char, c_int, c_double # TODO
|
||||
import scipy
|
||||
import warnings
|
||||
import os
|
||||
|
||||
from .config import config
|
||||
import logging
|
||||
from . import linalg_cython
|
||||
|
||||
|
||||
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
|
||||
_fix_dpotri_scipy_bug = True
|
||||
if np.all(_scipyversion >= np.array([0, 14])):
|
||||
from scipy.linalg import lapack
|
||||
_fix_dpotri_scipy_bug = False
|
||||
elif np.all(_scipyversion >= np.array([0, 12])):
|
||||
#import scipy.linalg.lapack.clapack as lapack
|
||||
from scipy.linalg import lapack
|
||||
else:
|
||||
from scipy.linalg.lapack import flapack as lapack
|
||||
|
||||
if config.getboolean('anaconda', 'installed') and config.getboolean('anaconda', 'MKL'):
|
||||
try:
|
||||
anaconda_path = str(config.get('anaconda', 'location'))
|
||||
mkl_rt = ctypes.cdll.LoadLibrary(os.path.join(anaconda_path, 'DLLs', 'mkl_rt.dll'))
|
||||
dsyrk = mkl_rt.dsyrk
|
||||
dsyr = mkl_rt.dsyr
|
||||
_blas_available = True
|
||||
print('anaconda installed and mkl is loaded')
|
||||
except:
|
||||
_blas_available = False
|
||||
else:
|
||||
try:
|
||||
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
|
||||
dsyrk = _blaslib.dsyrk_
|
||||
dsyr = _blaslib.dsyr_
|
||||
_blas_available = True
|
||||
except AttributeError as e:
|
||||
_blas_available = False
|
||||
warnings.warn("warning: caught this exception:" + str(e))
|
||||
|
||||
def force_F_ordered_symmetric(A):
|
||||
"""
|
||||
return a F ordered version of A, assuming A is symmetric
|
||||
|
|
@ -169,9 +137,6 @@ def dpotri(A, lower=1):
|
|||
:returns: A inverse
|
||||
|
||||
"""
|
||||
if _fix_dpotri_scipy_bug:
|
||||
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
|
||||
lower = 0
|
||||
|
||||
A = force_F_ordered(A)
|
||||
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
|
||||
|
|
@ -300,8 +265,8 @@ def pca(Y, input_dim):
|
|||
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
|
||||
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
|
||||
v = X.std(axis=0)
|
||||
X /= v;
|
||||
W *= v;
|
||||
X /= v
|
||||
W *= v
|
||||
return X, W.T
|
||||
|
||||
def ppca(Y, Q, iterations=100):
|
||||
|
|
@ -362,19 +327,15 @@ def tdot_blas(mat, out=None):
|
|||
BETA = c_double(0.0)
|
||||
C = out.ctypes.data_as(ctypes.c_void_p)
|
||||
LDC = c_int(np.max(out.strides) // 8)
|
||||
dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
|
||||
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
|
||||
blas.dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
|
||||
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
|
||||
|
||||
symmetrify(out, upper=True)
|
||||
|
||||
|
||||
return np.ascontiguousarray(out)
|
||||
|
||||
def tdot(*args, **kwargs):
|
||||
if _blas_available:
|
||||
return tdot_blas(*args, **kwargs)
|
||||
else:
|
||||
return tdot_numpy(*args, **kwargs)
|
||||
return tdot_blas(*args, **kwargs)
|
||||
|
||||
def DSYR_blas(A, x, alpha=1.):
|
||||
"""
|
||||
|
|
@ -393,8 +354,8 @@ def DSYR_blas(A, x, alpha=1.):
|
|||
A_ = A.ctypes.data_as(ctypes.c_void_p)
|
||||
x_ = x.ctypes.data_as(ctypes.c_void_p)
|
||||
INCX = c_int(1)
|
||||
dsyr(byref(UPLO), byref(N), byref(ALPHA),
|
||||
x_, byref(INCX), A_, byref(LDA))
|
||||
blas.dsyr(byref(UPLO), byref(N), byref(ALPHA),
|
||||
x_, byref(INCX), A_, byref(LDA))
|
||||
symmetrify(A, upper=True)
|
||||
|
||||
def DSYR_numpy(A, x, alpha=1.):
|
||||
|
|
@ -411,10 +372,8 @@ def DSYR_numpy(A, x, alpha=1.):
|
|||
|
||||
|
||||
def DSYR(*args, **kwargs):
|
||||
if _blas_available:
|
||||
return DSYR_blas(*args, **kwargs)
|
||||
else:
|
||||
return DSYR_numpy(*args, **kwargs)
|
||||
return DSYR_blas(*args, **kwargs)
|
||||
|
||||
|
||||
def symmetrify(A, upper=False):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue