mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
weaved linear kern
This commit is contained in:
parent
7f138b8b01
commit
ce2884f0a7
1 changed files with 83 additions and 24 deletions
|
|
@ -5,6 +5,7 @@
|
|||
from kernpart import kernpart
|
||||
import numpy as np
|
||||
from ..util.linalg import tdot
|
||||
from scipy import weave
|
||||
|
||||
class linear(kernpart):
|
||||
"""
|
||||
|
|
@ -171,33 +172,91 @@ class linear(kernpart):
|
|||
self._psi_computations(Z, mu, S)
|
||||
AZZA = self.ZA.T[:, None, :, None] * self.ZA[None, :, None, :]
|
||||
AZZA = AZZA + AZZA.swapaxes(1, 2)
|
||||
target_S += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1)
|
||||
dpsi2_dmu = (dL_dpsi2[:, :, :, None] * np.tensordot(mu, AZZA, (-1, 0))).sum(1).sum(1)
|
||||
target_mu += dpsi2_dmu
|
||||
AZZA_2 = AZZA/2.
|
||||
#muAZZA = np.tensordot(mu,AZZA,(-1,0))
|
||||
#target_mu_dummy, target_S_dummy = np.zeros_like(target_mu), np.zeros_like(target_S)
|
||||
#target_mu_dummy += (dL_dpsi2[:, :, :, None] * muAZZA).sum(1).sum(1)
|
||||
#target_S_dummy += (dL_dpsi2[:, :, :, None] * self.ZA[None, :, None, :] * self.ZA[None, None, :, :]).sum(1).sum(1)
|
||||
|
||||
#Using weave, we can exploiut the symmetry of this problem:
|
||||
code = """
|
||||
int n, m, mm,q,qq;
|
||||
double factor,tmp;
|
||||
#pragma omp parallel for private(m,mm,q,qq,factor,tmp)
|
||||
for(n=0;n<N;n++){
|
||||
for(m=0;m<M;m++){
|
||||
for(mm=0;mm<=m;mm++){
|
||||
//add in a factor of 2 for the off-diagonal terms (and then count them only once)
|
||||
if(m==mm)
|
||||
factor = dL_dpsi2(n,m,mm);
|
||||
else
|
||||
factor = 2.0*dL_dpsi2(n,m,mm);
|
||||
|
||||
for(q=0;q<Q;q++){
|
||||
|
||||
//take the dot product of mu[n,:] and AZZA[:,m,mm,q] TODO: blas!
|
||||
tmp = 0.0;
|
||||
for(qq=0;qq<Q;qq++){
|
||||
tmp += mu(n,qq)*AZZA(qq,m,mm,q);
|
||||
}
|
||||
|
||||
target_mu(n,q) += factor*tmp;
|
||||
target_S(n,q) += factor*AZZA_2(q,m,mm,q);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
support_code = """
|
||||
#include <omp.h>
|
||||
#include <math.h>
|
||||
"""
|
||||
weave_options = {'headers' : ['<omp.h>'],
|
||||
'extra_compile_args': ['-fopenmp -O3'], #-march=native'],
|
||||
'extra_link_args' : ['-lgomp']}
|
||||
|
||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
arg_names=['N','M','Q','mu','AZZA','AZZA_2','target_mu','target_S','dL_dpsi2'],
|
||||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
||||
|
||||
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
|
||||
self._psi_computations(Z, mu, S)
|
||||
# mu2_S = np.sum(self.mu2_S, 0) # Q,
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# psi2_dZ_real = np.zeros((mu.shape[0], Z.shape[0], Z.shape[1]))
|
||||
# for n in range(mu.shape[0]):
|
||||
# for m in range(Z.shape[0]):
|
||||
# tmp = self.variances * (tdot(self._mu[n:n + 1].T) + np.diag(S[n]))
|
||||
# psi2_dZ_real[n, m, :] = np.dot(tmp, (
|
||||
# self._Z[m:m + 1] * self.variances).T).T
|
||||
# tmp = self._Z[m:m + 1] * self.variances
|
||||
# tmp = np.dot(tmp, (tdot(self._mu[n:n + 1].T) + np.diag(S[n])))
|
||||
# psi2_dZ_real[n, m, :] = tmp * self.variances
|
||||
# for m_prime in range(Z.shape[0]):
|
||||
# if m == m_prime:
|
||||
# psi2_dZ_real[n, m, :] *= 2
|
||||
# prod = (dL_dpsi2[:, :, :, None] * np.eye(Z.shape[0])[None, :, :, None] * (self.ZAinner * self.variances).swapaxes(0, 1)[:, :, None, :])
|
||||
# psi2_dZ = prod.swapaxes(1, 2) + prod
|
||||
psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
|
||||
target += psi2_dZ.sum(0).sum(0)
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# psi2_dZ_old = (dL_dpsi2[:, :, :, None] * (self.mu2_S[:, None, None, :] * (Z * np.square(self.variances)[None, :])[None, None, :, :])).sum(0).sum(1)
|
||||
# target += (dL_dpsi2[:, :, :, None] * psi2_dZ_real[:, :, None, :]).sum(0).sum(0) * 2 # (self.variances * np.dot(self.inner, self.ZA.T)).sum(1)
|
||||
#psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
|
||||
#dummy_target = np.zeros_like(target)
|
||||
#dummy_target += psi2_dZ.sum(0).sum(0)
|
||||
|
||||
AZA = self.variances*self.ZAinner
|
||||
code="""
|
||||
int n,m,mm,q;
|
||||
#pragma omp parallel for private(n,mm,q)
|
||||
for(m=0;m<M;m++){
|
||||
for(q=0;q<Q;q++){
|
||||
for(mm=0;mm<M;mm++){
|
||||
for(n=0;n<N;n++){
|
||||
target(m,q) += dL_dpsi2(n,m,mm)*AZA(n,mm,q);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
support_code = """
|
||||
#include <omp.h>
|
||||
#include <math.h>
|
||||
"""
|
||||
weave_options = {'headers' : ['<omp.h>'],
|
||||
'extra_compile_args': ['-fopenmp -O3'], #-march=native'],
|
||||
'extra_link_args' : ['-lgomp']}
|
||||
|
||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
arg_names=['N','M','Q','AZA','target','dL_dpsi2'],
|
||||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#---------------------------------------#
|
||||
# Precomputations #
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue