diff --git a/GPy/core/gp.py b/GPy/core/gp.py index ab725897..214c2324 100644 --- a/GPy/core/gp.py +++ b/GPy/core/gp.py @@ -26,7 +26,7 @@ class GP(Model): """ - def __init__(self, X, Y, kernel, likelihood, inference_method=None, name='gp'): + def __init__(self, X, Y, kernel, likelihood, inference_method=None, Y_metadata=None, name='gp'): super(GP, self).__init__(name) assert X.ndim == 2 @@ -38,6 +38,12 @@ class GP(Model): assert Y.shape[0] == self.num_data _, self.output_dim = self.Y.shape + if Y_metadata is not None: + assert Y_metadata.shape == self.Y.shape + self.Y_metadata = ObservableArray(Y_metadata) + else: + self.Y_metadata = None + assert isinstance(kernel, kern.kern) self.kern = kernel @@ -58,7 +64,7 @@ class GP(Model): self.parameters_changed() def parameters_changed(self): - self.posterior, self._log_marginal_likelihood, grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y) + self.posterior, self._log_marginal_likelihood, grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y, Y_metadata=self.Y_metadata) self._dL_dK = grad_dict['dL_dK'] def log_likelihood(self):