mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
start of cythoning coregionalize
This commit is contained in:
parent
2e8ce34ee0
commit
c00f76d250
3 changed files with 56 additions and 1 deletions
|
|
@ -111,6 +111,11 @@ class Coregionalize(Kern):
|
||||||
weave.inline(code, ['target', 'index', 'index2', 'N', 'num_inducing', 'B', 'output_dim'])
|
weave.inline(code, ['target', 'index', 'index2', 'N', 'num_inducing', 'B', 'output_dim'])
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
def _K_cython(self, X, X2=None):
|
||||||
|
if X2 is None:
|
||||||
|
return coregionalize_cython.K_symmetric(self.B, X[:,0])
|
||||||
|
return coregionalize_cython.K_asymmetric(self.B, X[:,0], X2[:,0])
|
||||||
|
|
||||||
|
|
||||||
def Kdiag(self, X):
|
def Kdiag(self, X):
|
||||||
return np.diag(self.B)[np.asarray(X, dtype=np.int).flatten()]
|
return np.diag(self.B)[np.asarray(X, dtype=np.int).flatten()]
|
||||||
|
|
@ -164,6 +169,11 @@ class Coregionalize(Kern):
|
||||||
dL_dK_small[j,i] = tmp1[:,index2==j].sum()
|
dL_dK_small[j,i] = tmp1[:,index2==j].sum()
|
||||||
return dL_dK_small
|
return dL_dK_small
|
||||||
|
|
||||||
|
def gradient_reduce_cython(self, dL_dK, index, index2):
|
||||||
|
index, index2 = index[:,0], index2[:,0]
|
||||||
|
return coregionalize_cython.gradient_reduce(self.output_dim, dL_dK, index, index2
|
||||||
|
|
||||||
|
|
||||||
def update_gradients_diag(self, dL_dKdiag, X):
|
def update_gradients_diag(self, dL_dKdiag, X):
|
||||||
index = np.asarray(X, dtype=np.int).flatten()
|
index = np.asarray(X, dtype=np.int).flatten()
|
||||||
dL_dKdiag_small = np.array([dL_dKdiag[index==i].sum() for i in range(self.output_dim)])
|
dL_dKdiag_small = np.array([dL_dKdiag[index==i].sum() for i in range(self.output_dim)])
|
||||||
|
|
|
||||||
41
GPy/kern/_src/coregionalize_cython.pyx
Normal file
41
GPy/kern/_src/coregionalize_cython.pyx
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
import cython
|
||||||
|
import numpy as np
|
||||||
|
cimport numpy as np
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
@cython.wraparound(False)
|
||||||
|
@cython.nonecheck(False)
|
||||||
|
def K_symmetric(np.ndarray[double, ndim=2] B, np.ndarray[int, ndim=1] X):
|
||||||
|
N = X.size
|
||||||
|
K = np.zeros((N, N))
|
||||||
|
for n in range(N):
|
||||||
|
for m in range(N):
|
||||||
|
K[n,m] = B[X[n],X[m]]
|
||||||
|
return K
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
@cython.wraparound(False)
|
||||||
|
@cython.nonecheck(False)
|
||||||
|
def K_asymmetric(np.ndarray[double, ndim=2] B, np.ndarray[int, ndim=1] X, np.ndarray[int, ndim=1] X2):
|
||||||
|
N = X.size
|
||||||
|
M = X2.size
|
||||||
|
K = np.zeros((N, M))
|
||||||
|
for n in range(N):
|
||||||
|
for m in range(M):
|
||||||
|
K[n,m] = B[X[n],X2[m]]
|
||||||
|
return K
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
@cython.wraparound(False)
|
||||||
|
@cython.nonecheck(False)
|
||||||
|
def gradient_reduce(int D, np.ndarray[double, ndim=2] dL_dK, np.ndarray[int, ndim=1] index, np.ndarray[int, ndim=1] index2):
|
||||||
|
dL_dK_small = np.zeros((D, D))
|
||||||
|
N = index.size
|
||||||
|
M = index2.size
|
||||||
|
for i in range(M):
|
||||||
|
for j in range(N):
|
||||||
|
dL_dK_small[index[j] + D*index2[i]] += dL_dK[i+j*M];
|
||||||
|
return dL_dK_small
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
6
setup.py
6
setup.py
|
|
@ -17,7 +17,11 @@ ext_mods = [Extension(name='GPy.kern._src.stationary_cython',
|
||||||
sources=['GPy/kern/_src/stationary_cython.c','GPy/kern/_src/stationary_utils.c'],
|
sources=['GPy/kern/_src/stationary_cython.c','GPy/kern/_src/stationary_utils.c'],
|
||||||
include_dirs=[np.get_include()],
|
include_dirs=[np.get_include()],
|
||||||
extra_compile_args=compile_flags,
|
extra_compile_args=compile_flags,
|
||||||
extra_link_args = ['-lgomp'])]
|
extra_link_args = ['-lgomp']),
|
||||||
|
Extension(name='GPy.kern._src.coregionalize_cython',
|
||||||
|
sources=['GPy/kern/_src/coregionalize_cython.c','GPy/kern/_src/coregionalize_cython.c'],
|
||||||
|
include_dirs=[np.get_include()],
|
||||||
|
extra_compile_args=compile_flags)]
|
||||||
|
|
||||||
setup(name = 'GPy',
|
setup(name = 'GPy',
|
||||||
version = version,
|
version = version,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue