FIX: now Scipy 0.16 is required, removing fixes for older versions. Accessing blas through the scipy interface

This commit is contained in:
David Menéndez Hurtado 2015-08-24 12:56:18 +02:00
parent f221a3b1fa
commit ba2ea3eb73

View file

@ -7,48 +7,16 @@
import numpy as np import numpy as np
from scipy import linalg from scipy import linalg
import types from scipy.linalg import lapack, blas
import ctypes import ctypes
from ctypes import byref, c_char, c_int, c_double # TODO from ctypes import byref, c_char, c_int, c_double # TODO
import scipy
import warnings
import os
from .config import config from .config import config
import logging import logging
from . import linalg_cython from . import linalg_cython
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
_fix_dpotri_scipy_bug = True
if np.all(_scipyversion >= np.array([0, 14])):
from scipy.linalg import lapack
_fix_dpotri_scipy_bug = False
elif np.all(_scipyversion >= np.array([0, 12])):
#import scipy.linalg.lapack.clapack as lapack
from scipy.linalg import lapack
else:
from scipy.linalg.lapack import flapack as lapack
if config.getboolean('anaconda', 'installed') and config.getboolean('anaconda', 'MKL'):
try:
anaconda_path = str(config.get('anaconda', 'location'))
mkl_rt = ctypes.cdll.LoadLibrary(os.path.join(anaconda_path, 'DLLs', 'mkl_rt.dll'))
dsyrk = mkl_rt.dsyrk
dsyr = mkl_rt.dsyr
_blas_available = True
print('anaconda installed and mkl is loaded')
except:
_blas_available = False
else:
try:
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
dsyrk = _blaslib.dsyrk_
dsyr = _blaslib.dsyr_
_blas_available = True
except AttributeError as e:
_blas_available = False
warnings.warn("warning: caught this exception:" + str(e))
def force_F_ordered_symmetric(A): def force_F_ordered_symmetric(A):
""" """
return a F ordered version of A, assuming A is symmetric return a F ordered version of A, assuming A is symmetric
@ -169,9 +137,6 @@ def dpotri(A, lower=1):
:returns: A inverse :returns: A inverse
""" """
if _fix_dpotri_scipy_bug:
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
lower = 0
A = force_F_ordered(A) A = force_F_ordered(A)
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
@ -300,8 +265,8 @@ def pca(Y, input_dim):
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False) Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]] [X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
v = X.std(axis=0) v = X.std(axis=0)
X /= v; X /= v
W *= v; W *= v
return X, W.T return X, W.T
def ppca(Y, Q, iterations=100): def ppca(Y, Q, iterations=100):
@ -362,19 +327,15 @@ def tdot_blas(mat, out=None):
BETA = c_double(0.0) BETA = c_double(0.0)
C = out.ctypes.data_as(ctypes.c_void_p) C = out.ctypes.data_as(ctypes.c_void_p)
LDC = c_int(np.max(out.strides) // 8) LDC = c_int(np.max(out.strides) // 8)
dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K), blas.dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC)) byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
symmetrify(out, upper=True) symmetrify(out, upper=True)
return np.ascontiguousarray(out) return np.ascontiguousarray(out)
def tdot(*args, **kwargs): def tdot(*args, **kwargs):
if _blas_available: return tdot_blas(*args, **kwargs)
return tdot_blas(*args, **kwargs)
else:
return tdot_numpy(*args, **kwargs)
def DSYR_blas(A, x, alpha=1.): def DSYR_blas(A, x, alpha=1.):
""" """
@ -393,8 +354,8 @@ def DSYR_blas(A, x, alpha=1.):
A_ = A.ctypes.data_as(ctypes.c_void_p) A_ = A.ctypes.data_as(ctypes.c_void_p)
x_ = x.ctypes.data_as(ctypes.c_void_p) x_ = x.ctypes.data_as(ctypes.c_void_p)
INCX = c_int(1) INCX = c_int(1)
dsyr(byref(UPLO), byref(N), byref(ALPHA), blas.dsyr(byref(UPLO), byref(N), byref(ALPHA),
x_, byref(INCX), A_, byref(LDA)) x_, byref(INCX), A_, byref(LDA))
symmetrify(A, upper=True) symmetrify(A, upper=True)
def DSYR_numpy(A, x, alpha=1.): def DSYR_numpy(A, x, alpha=1.):
@ -411,10 +372,8 @@ def DSYR_numpy(A, x, alpha=1.):
def DSYR(*args, **kwargs): def DSYR(*args, **kwargs):
if _blas_available: return DSYR_blas(*args, **kwargs)
return DSYR_blas(*args, **kwargs)
else:
return DSYR_numpy(*args, **kwargs)
def symmetrify(A, upper=False): def symmetrify(A, upper=False):
""" """