FIX: now Scipy 0.16 is required, removing fixes for older versions. Accessing blas through the scipy interface

This commit is contained in:
David Menéndez Hurtado 2015-08-24 12:56:18 +02:00
parent f221a3b1fa
commit ba2ea3eb73

View file

@ -7,48 +7,16 @@
import numpy as np
from scipy import linalg
import types
from scipy.linalg import lapack, blas
import ctypes
from ctypes import byref, c_char, c_int, c_double # TODO
import scipy
import warnings
import os
from .config import config
import logging
from . import linalg_cython
_scipyversion = np.float64((scipy.__version__).split('.')[:2])
_fix_dpotri_scipy_bug = True
if np.all(_scipyversion >= np.array([0, 14])):
from scipy.linalg import lapack
_fix_dpotri_scipy_bug = False
elif np.all(_scipyversion >= np.array([0, 12])):
#import scipy.linalg.lapack.clapack as lapack
from scipy.linalg import lapack
else:
from scipy.linalg.lapack import flapack as lapack
if config.getboolean('anaconda', 'installed') and config.getboolean('anaconda', 'MKL'):
try:
anaconda_path = str(config.get('anaconda', 'location'))
mkl_rt = ctypes.cdll.LoadLibrary(os.path.join(anaconda_path, 'DLLs', 'mkl_rt.dll'))
dsyrk = mkl_rt.dsyrk
dsyr = mkl_rt.dsyr
_blas_available = True
print('anaconda installed and mkl is loaded')
except:
_blas_available = False
else:
try:
_blaslib = ctypes.cdll.LoadLibrary(np.core._dotblas.__file__) # @UndefinedVariable
dsyrk = _blaslib.dsyrk_
dsyr = _blaslib.dsyr_
_blas_available = True
except AttributeError as e:
_blas_available = False
warnings.warn("warning: caught this exception:" + str(e))
def force_F_ordered_symmetric(A):
"""
return a F ordered version of A, assuming A is symmetric
@ -169,9 +137,6 @@ def dpotri(A, lower=1):
:returns: A inverse
"""
if _fix_dpotri_scipy_bug:
assert lower==1, "scipy linalg behaviour is very weird. please use lower, fortran ordered arrays"
lower = 0
A = force_F_ordered(A)
R, info = lapack.dpotri(A, lower=lower) #needs to be zero here, seems to be a scipy bug
@ -300,8 +265,8 @@ def pca(Y, input_dim):
Z = linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
[X, W] = [Z[0][:, 0:input_dim], np.dot(np.diag(Z[1]), Z[2]).T[:, 0:input_dim]]
v = X.std(axis=0)
X /= v;
W *= v;
X /= v
W *= v
return X, W.T
def ppca(Y, Q, iterations=100):
@ -362,19 +327,15 @@ def tdot_blas(mat, out=None):
BETA = c_double(0.0)
C = out.ctypes.data_as(ctypes.c_void_p)
LDC = c_int(np.max(out.strides) // 8)
dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
blas.dsyrk(byref(UPLO), byref(TRANS), byref(N), byref(K),
byref(ALPHA), A, byref(LDA), byref(BETA), C, byref(LDC))
symmetrify(out, upper=True)
return np.ascontiguousarray(out)
def tdot(*args, **kwargs):
if _blas_available:
return tdot_blas(*args, **kwargs)
else:
return tdot_numpy(*args, **kwargs)
return tdot_blas(*args, **kwargs)
def DSYR_blas(A, x, alpha=1.):
"""
@ -393,8 +354,8 @@ def DSYR_blas(A, x, alpha=1.):
A_ = A.ctypes.data_as(ctypes.c_void_p)
x_ = x.ctypes.data_as(ctypes.c_void_p)
INCX = c_int(1)
dsyr(byref(UPLO), byref(N), byref(ALPHA),
x_, byref(INCX), A_, byref(LDA))
blas.dsyr(byref(UPLO), byref(N), byref(ALPHA),
x_, byref(INCX), A_, byref(LDA))
symmetrify(A, upper=True)
def DSYR_numpy(A, x, alpha=1.):
@ -411,10 +372,8 @@ def DSYR_numpy(A, x, alpha=1.):
def DSYR(*args, **kwargs):
if _blas_available:
return DSYR_blas(*args, **kwargs)
else:
return DSYR_numpy(*args, **kwargs)
return DSYR_blas(*args, **kwargs)
def symmetrify(A, upper=False):
"""