grad dict is property of self

This commit is contained in:
Max Zwiessele 2014-03-13 09:07:56 +00:00
parent eb8b2c8b47
commit 7adf5217f2
2 changed files with 4 additions and 4 deletions

View file

@ -67,9 +67,9 @@ class GP(Model):
self.add_parameter(self.likelihood) self.add_parameter(self.likelihood)
def parameters_changed(self): def parameters_changed(self):
self.posterior, self._log_marginal_likelihood, grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y, Y_metadata=self.Y_metadata) self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y, Y_metadata=self.Y_metadata)
self.likelihood.update_gradients(np.diag(grad_dict['dL_dK'])) self.likelihood.update_gradients(np.diag(self.grad_dict['dL_dK']))
self.kern.update_gradients_full(grad_dict['dL_dK'], self.X) self.kern.update_gradients_full(self.grad_dict['dL_dK'], self.X)
def log_likelihood(self): def log_likelihood(self):
return self._log_marginal_likelihood return self._log_marginal_likelihood

View file

@ -41,7 +41,7 @@ class GPLVM(GP):
def parameters_changed(self): def parameters_changed(self):
super(GPLVM, self).parameters_changed() super(GPLVM, self).parameters_changed()
self.X.gradient = self.kern.gradients_X(self.dL_dK, self.X, None) self.X.gradient = self.kern.gradients_X(self.grad_dict['dL_dK'], self.X, None)
def _getstate(self): def _getstate(self):
return GP._getstate(self) return GP._getstate(self)