From ce2884f0a7dc26087a5225bc92e39643920e3e16 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Tue, 7 May 2013 18:02:10 +0100 Subject: [PATCH] weaved linear kern --- GPy/kern/linear.py | 107 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 83 insertions(+), 24 deletions(-) diff --git a/GPy/kern/linear.py b/GPy/kern/linear.py index 396b1aec..16ef2499 100644 --- a/GPy/kern/linear.py +++ b/GPy/kern/linear.py @@ -5,6 +5,7 @@ from kernpart import kernpart import numpy as np from ..util.linalg import tdot +from scipy import weave class linear(kernpart): """ @@ -171,33 +172,91 @@ class linear(kernpart): self._psi_computations(Z, mu, S) AZZA = self.ZA.T[:, None, :, None] * self.ZA[None, :, None, :] AZZA = AZZA + AZZA.swapaxes(1, 2) - target_S += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1) - dpsi2_dmu = (dL_dpsi2[:, :, :, None] * np.tensordot(mu, AZZA, (-1, 0))).sum(1).sum(1) - target_mu += dpsi2_dmu + AZZA_2 = AZZA/2. + #muAZZA = np.tensordot(mu,AZZA,(-1,0)) + #target_mu_dummy, target_S_dummy = np.zeros_like(target_mu), np.zeros_like(target_S) + #target_mu_dummy += (dL_dpsi2[:, :, :, None] * muAZZA).sum(1).sum(1) + #target_S_dummy += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1) + + #Using weave, we can exploiut the symmetry of this problem: + code = """ + int n, m, mm,q,qq; + double factor,tmp; + #pragma omp parallel for private(m,mm,q,qq,factor,tmp) + for(n=0;n + #include + """ + weave_options = {'headers' : [''], + 'extra_compile_args': ['-fopenmp -O3'], #-march=native'], + 'extra_link_args' : ['-lgomp']} + + N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] + weave.inline(code, support_code=support_code, libraries=['gomp'], + arg_names=['N','M','Q','mu','AZZA','AZZA_2','target_mu','target_S','dL_dpsi2'], + type_converters=weave.converters.blitz,**weave_options) + def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target): self._psi_computations(Z, mu, S) -# mu2_S = np.sum(self.mu2_S, 0) # Q, -# import ipdb;ipdb.set_trace() -# psi2_dZ_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[1])) -# for n in range(mu.shape[0]): -# for m in range(Z.shape[0]): -# tmp = self.variances * (tdot(self._mu[n:n + 1].T) + np.diag(S[n])) -# psi2_dZ_real[n, m, :] = np.dot(tmp, ( -# self._Z[m:m + 1] * self.variances).T).T -# tmp = self._Z[m:m + 1] * self.variances -# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n]))) -# psi2_dZ_real[n, m, :] = tmp * self.variances -# for m_prime in range(Z.shape[0]): -# if m == m_prime: -# psi2_dZ_real[n, m, :] *= 2 -# prod = (dL_dpsi2[:, :, :, None] * np.eye(Z.shape[0])[None, :, :, None] * (self.ZAinner * self.variances).swapaxes(0, 1)[:, :, None, :]) -# psi2_dZ = prod.swapaxes(1, 2) + prod - psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :] - target += psi2_dZ.sum(0).sum(0) -# import ipdb;ipdb.set_trace() -# psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1) -# target += (dL_dpsi2[:, :, :, None] * psi2_dZ_real[:, :, None, :]).sum(0).sum(0) * 2 # (self.variances * np.dot(self.inner, self.ZA.T)).sum(1) + #psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :] + #dummy_target = np.zeros_like(target) + #dummy_target += psi2_dZ.sum(0).sum(0) + + AZA = self.variances*self.ZAinner + code=""" + int n,m,mm,q; + #pragma omp parallel for private(n,mm,q) + for(m=0;m + #include + """ + weave_options = {'headers' : [''], + 'extra_compile_args': ['-fopenmp -O3'], #-march=native'], + 'extra_link_args' : ['-lgomp']} + + N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] + weave.inline(code, support_code=support_code, libraries=['gomp'], + arg_names=['N','M','Q','AZA','target','dL_dpsi2'], + type_converters=weave.converters.blitz,**weave_options) + + + + #---------------------------------------# # Precomputations #