linear and rbf fix for variational gradients in Z

This commit is contained in:
Max Zwiessele 2014-02-21 12:29:28 +00:00
parent 0dc9a32ba3
commit 8b8ca5544f
5 changed files with 102 additions and 113 deletions

View file

@ -140,9 +140,8 @@ class Linear(Kern):
if self.ARD: grad += tmp.sum(0).sum(0).sum(0)
else: grad += tmp.sum()
#from Kmm
self.update_gradients_full(dL_dpsi1, mu, Z)
grad += self.variances.gradient
self._set_gradient(grad)
self.update_gradients_full(dL_dKmm, Z, None)
self.variances.gradient += grad
def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
# Kmm
@ -221,7 +220,6 @@ class Linear(Kern):
def _weave_dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
AZA = self.variances*self._ZAinner(mu, S, Z)
code="""
int n,m,mm,q;
@ -230,7 +228,7 @@ class Linear(Kern):
for(q=0;q<input_dim;q++){
for(mm=0;mm<num_inducing;mm++){
for(n=0;n<N;n++){
target(m,q) += dL_dpsi2(n,m,mm)*AZA(n,mm,q);
target(m,q) += 2*dL_dpsi2(n,m,mm)*AZA(n,mm,q);
}
}
}
@ -245,7 +243,7 @@ class Linear(Kern):
'extra_link_args' : ['-lgomp']}
N,num_inducing,input_dim = mu.shape[0],Z.shape[0],mu.shape[1]
mu, AZA, target, dL_dpsi2 = param_to_array(mu, AZA, target, dL_dpsi2)
mu = param_to_array(mu)
weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['N','num_inducing','input_dim','AZA','target','dL_dpsi2'],
type_converters=weave.converters.blitz,**weave_options)