weaved linear kern

This commit is contained in:
James Hensman 2013-05-07 18:02:10 +01:00
parent 7f138b8b01
commit ce2884f0a7

View file

@ -5,6 +5,7 @@
from kernpart import kernpart
import numpy as np
from ..util.linalg import tdot
from scipy import weave
class linear(kernpart):
"""
@ -171,33 +172,91 @@ class linear(kernpart):
self._psi_computations(Z, mu, S)
AZZA = self.ZA.T[:, None, :, None] * self.ZA[None, :, None, :]
AZZA = AZZA + AZZA.swapaxes(1, 2)
target_S += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1)
dpsi2_dmu = (dL_dpsi2[:, :, :, None] * np.tensordot(mu, AZZA, (-1, 0))).sum(1).sum(1)
target_mu += dpsi2_dmu
AZZA_2 = AZZA/2.
#muAZZA = np.tensordot(mu,AZZA,(-1,0))
#target_mu_dummy, target_S_dummy = np.zeros_like(target_mu), np.zeros_like(target_S)
#target_mu_dummy += (dL_dpsi2[:, :, :, None] * muAZZA).sum(1).sum(1)
#target_S_dummy += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1)
#Using weave, we can exploiut the symmetry of this problem:
code = """
int n, m, mm,q,qq;
double factor,tmp;
#pragma omp parallel for private(m,mm,q,qq,factor,tmp)
for(n=0;n<N;n++){
for(m=0;m<M;m++){
for(mm=0;mm<=m;mm++){
//add in a factor of 2 for the off-diagonal terms (and then count them only once)
if(m==mm)
factor = dL_dpsi2(n,m,mm);
else
factor = 2.0*dL_dpsi2(n,m,mm);
for(q=0;q<Q;q++){
//take the dot product of mu[n,:] and AZZA[:,m,mm,q] TODO: blas!
tmp = 0.0;
for(qq=0;qq<Q;qq++){
tmp += mu(n,qq)*AZZA(qq,m,mm,q);
}
target_mu(n,q) += factor*tmp;
target_S(n,q) += factor*AZZA_2(q,m,mm,q);
}
}
}
}
"""
support_code = """
#include <omp.h>
#include <math.h>
"""
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -O3'], #-march=native'],
'extra_link_args' : ['-lgomp']}
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['N','M','Q','mu','AZZA','AZZA_2','target_mu','target_S','dL_dpsi2'],
type_converters=weave.converters.blitz,**weave_options)
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
self._psi_computations(Z, mu, S)
# mu2_S = np.sum(self.mu2_S, 0) # Q,
# import ipdb;ipdb.set_trace()
# psi2_dZ_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[1]))
# for n in range(mu.shape[0]):
# for m in range(Z.shape[0]):
# tmp = self.variances * (tdot(self._mu[n:n + 1].T) + np.diag(S[n]))
# psi2_dZ_real[n, m, :] = np.dot(tmp, (
# self._Z[m:m + 1] * self.variances).T).T
# tmp = self._Z[m:m + 1] * self.variances
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
# psi2_dZ_real[n, m, :] = tmp * self.variances
# for m_prime in range(Z.shape[0]):
# if m == m_prime:
# psi2_dZ_real[n, m, :] *= 2
# prod = (dL_dpsi2[:, :, :, None] * np.eye(Z.shape[0])[None, :, :, None] * (self.ZAinner * self.variances).swapaxes(0, 1)[:, :, None, :])
# psi2_dZ = prod.swapaxes(1, 2) + prod
psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
target += psi2_dZ.sum(0).sum(0)
# import ipdb;ipdb.set_trace()
# psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
# target += (dL_dpsi2[:, :, :, None] * psi2_dZ_real[:, :, None, :]).sum(0).sum(0) * 2 # (self.variances * np.dot(self.inner, self.ZA.T)).sum(1)
#psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
#dummy_target = np.zeros_like(target)
#dummy_target += psi2_dZ.sum(0).sum(0)
AZA = self.variances*self.ZAinner
code="""
int n,m,mm,q;
#pragma omp parallel for private(n,mm,q)
for(m=0;m<M;m++){
for(q=0;q<Q;q++){
for(mm=0;mm<M;mm++){
for(n=0;n<N;n++){
target(m,q) += dL_dpsi2(n,m,mm)*AZA(n,mm,q);
}
}
}
}
"""
support_code = """
#include <omp.h>
#include <math.h>
"""
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -O3'], #-march=native'],
'extra_link_args' : ['-lgomp']}
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['N','M','Q','AZA','target','dL_dpsi2'],
type_converters=weave.converters.blitz,**weave_options)
#---------------------------------------#
# Precomputations #