Minor change to updating of gradients in multioutput kernel

This commit is contained in:
Siivola Eero 2018-09-05 14:11:29 +03:00
parent 0b5e8d895e
commit f4629c89cd
2 changed files with 10 additions and 4 deletions

View file

@ -201,6 +201,12 @@ class Kern(Parameterized):
def reset_gradients(self): def reset_gradients(self):
raise NotImplementedError raise NotImplementedError
def get_gradient(self):
return self.gradient.copy()
def append_gradient(self,gradient):
self.gradient += gradient
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
""" """
Set the gradients of all parameters when doing inference with Set the gradients of all parameters when doing inference with

View file

@ -94,14 +94,14 @@ class MultioutputKern(CombinationKernel):
return target return target
def _update_gradients_full_wrapper(self, cov_struct, dL_dK, X, X2): def _update_gradients_full_wrapper(self, cov_struct, dL_dK, X, X2):
gradient = cov_struct['kern'].gradient.copy() gradient = cov_struct['kern'].get_gradient()
cov_struct['update_gradients_full'](dL_dK, X, X2) cov_struct['update_gradients_full'](dL_dK, X, X2)
cov_struct['kern'].gradient += gradient cov_struct['kern'].append_gradient(gradient)
def _update_gradients_diag_wrapper(self, kern, dL_dKdiag, X): def _update_gradients_diag_wrapper(self, kern, dL_dKdiag, X):
gradient = kern.gradient.copy() gradient = kern.get_gradient()
kern.update_gradients_diag(dL_dKdiag, X) kern.update_gradients_diag(dL_dKdiag, X)
kern.gradient += gradient kern.append_gradient(gradient)
def reset_gradients(self): def reset_gradients(self):
for kern in self.kern: kern.reset_gradients() for kern in self.kern: kern.reset_gradients()