mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
linear and rbf fix for variational gradients in Z
This commit is contained in:
parent
0dc9a32ba3
commit
8b8ca5544f
5 changed files with 102 additions and 113 deletions
|
|
@ -140,9 +140,8 @@ class Linear(Kern):
|
|||
if self.ARD: grad += tmp.sum(0).sum(0).sum(0)
|
||||
else: grad += tmp.sum()
|
||||
#from Kmm
|
||||
self.update_gradients_full(dL_dpsi1, mu, Z)
|
||||
grad += self.variances.gradient
|
||||
self._set_gradient(grad)
|
||||
self.update_gradients_full(dL_dKmm, Z, None)
|
||||
self.variances.gradient += grad
|
||||
|
||||
def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
|
||||
# Kmm
|
||||
|
|
@ -221,7 +220,6 @@ class Linear(Kern):
|
|||
|
||||
|
||||
def _weave_dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
|
||||
|
||||
AZA = self.variances*self._ZAinner(mu, S, Z)
|
||||
code="""
|
||||
int n,m,mm,q;
|
||||
|
|
@ -230,7 +228,7 @@ class Linear(Kern):
|
|||
for(q=0;q<input_dim;q++){
|
||||
for(mm=0;mm<num_inducing;mm++){
|
||||
for(n=0;n<N;n++){
|
||||
target(m,q) += dL_dpsi2(n,m,mm)*AZA(n,mm,q);
|
||||
target(m,q) += 2*dL_dpsi2(n,m,mm)*AZA(n,mm,q);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -245,7 +243,7 @@ class Linear(Kern):
|
|||
'extra_link_args' : ['-lgomp']}
|
||||
|
||||
N,num_inducing,input_dim = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
mu, AZA, target, dL_dpsi2 = param_to_array(mu, AZA, target, dL_dpsi2)
|
||||
mu = param_to_array(mu)
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
arg_names=['N','num_inducing','input_dim','AZA','target','dL_dpsi2'],
|
||||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue