diff --git a/GPy/kern/_src/mlp.py b/GPy/kern/_src/mlp.py index 4488ea82..e319b8b4 100644 --- a/GPy/kern/_src/mlp.py +++ b/GPy/kern/_src/mlp.py @@ -78,7 +78,7 @@ class MLP(Kern): *((vec1[:, None]+vec2[None, :])*self.weight_variance + 2*self.bias_variance + 2.))*base_cov_grad).sum() - def update_gradients_diag(self, X): + def update_gradients_diag(self, dL_dKdiag, X): self._K_diag_computations(X) self.variance.gradient = np.sum(self._K_diag_dvar*dL_dKdiag)