From 15901a48d43566eef9995eeeed7ee9df6eb3c146 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 17 Mar 2014 18:58:31 +0000 Subject: [PATCH] Changes to allow compatibility with mixed noise likelihoods --- GPy/core/sparse_gp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/GPy/core/sparse_gp.py b/GPy/core/sparse_gp.py index d137ceff..a0b09564 100644 --- a/GPy/core/sparse_gp.py +++ b/GPy/core/sparse_gp.py @@ -31,7 +31,7 @@ class SparseGP(GP): """ - def __init__(self, X, Y, Z, kernel, likelihood, inference_method=None, name='sparse gp'): + def __init__(self, X, Y, Z, kernel, likelihood, inference_method=None, name='sparse gp', Y_metadata=None): #pick a sensible inference method if inference_method is None: @@ -45,7 +45,7 @@ class SparseGP(GP): self.Z = Param('inducing inputs', Z) self.num_inducing = Z.shape[0] - GP.__init__(self, X, Y, kernel, likelihood, inference_method=inference_method, name=name) + GP.__init__(self, X, Y, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata) self.add_parameter(self.Z, index=0) @@ -53,7 +53,7 @@ class SparseGP(GP): return isinstance(self.X, VariationalPosterior) def parameters_changed(self): - self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y) + self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata) self.likelihood.update_gradients(self.grad_dict['dL_dthetaL']) if isinstance(self.X, VariationalPosterior): #gradients wrt kernel @@ -75,7 +75,6 @@ class SparseGP(GP): target += self.kern.gradient self.kern.update_gradients_full(self.grad_dict['dL_dKmm'], self.Z, None) self.kern.gradient += target - #gradients wrt Z self.Z.gradient[:,self.kern.active_dims] = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z) self.Z.gradient[:,self.kern.active_dims] += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X)