mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
Minor change to updating of gradients in multioutput kernel
This commit is contained in:
parent
0b5e8d895e
commit
f4629c89cd
2 changed files with 10 additions and 4 deletions
|
|
@ -201,6 +201,12 @@ class Kern(Parameterized):
|
||||||
def reset_gradients(self):
|
def reset_gradients(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_gradient(self):
|
||||||
|
return self.gradient.copy()
|
||||||
|
|
||||||
|
def append_gradient(self,gradient):
|
||||||
|
self.gradient += gradient
|
||||||
|
|
||||||
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
"""
|
"""
|
||||||
Set the gradients of all parameters when doing inference with
|
Set the gradients of all parameters when doing inference with
|
||||||
|
|
|
||||||
|
|
@ -94,14 +94,14 @@ class MultioutputKern(CombinationKernel):
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def _update_gradients_full_wrapper(self, cov_struct, dL_dK, X, X2):
|
def _update_gradients_full_wrapper(self, cov_struct, dL_dK, X, X2):
|
||||||
gradient = cov_struct['kern'].gradient.copy()
|
gradient = cov_struct['kern'].get_gradient()
|
||||||
cov_struct['update_gradients_full'](dL_dK, X, X2)
|
cov_struct['update_gradients_full'](dL_dK, X, X2)
|
||||||
cov_struct['kern'].gradient += gradient
|
cov_struct['kern'].append_gradient(gradient)
|
||||||
|
|
||||||
def _update_gradients_diag_wrapper(self, kern, dL_dKdiag, X):
|
def _update_gradients_diag_wrapper(self, kern, dL_dKdiag, X):
|
||||||
gradient = kern.gradient.copy()
|
gradient = kern.get_gradient()
|
||||||
kern.update_gradients_diag(dL_dKdiag, X)
|
kern.update_gradients_diag(dL_dKdiag, X)
|
||||||
kern.gradient += gradient
|
kern.append_gradient(gradient)
|
||||||
|
|
||||||
def reset_gradients(self):
|
def reset_gradients(self):
|
||||||
for kern in self.kern: kern.reset_gradients()
|
for kern in self.kern: kern.reset_gradients()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue