diff --git a/GPy/core/gp.py b/GPy/core/gp.py index 38a7bb3d..732db7e2 100644 --- a/GPy/core/gp.py +++ b/GPy/core/gp.py @@ -34,7 +34,7 @@ class GP(Model): """ - def __init__(self, X, Y, kernel, likelihood, inference_method=None, name='gp', Y_metadata=None, normalizer=False): + def __init__(self, X, Y, kernel, likelihood, mean_function=None, inference_method=None, name='gp', Y_metadata=None, normalizer=False): super(GP, self).__init__(name) assert X.ndim == 2 @@ -75,6 +75,15 @@ class GP(Model): assert isinstance(likelihood, likelihoods.Likelihood) self.likelihood = likelihood + #handle the mean function + self.mean_function = mean_function + if mean_function is not None: + assert isinstance(self.mean_function, Mapping) + assert mean_function.input_dim == self.input_dim + assert mean_function.output_dim == self.output_dim + self.add_parameter(mean_function) + + #find a sensible inference method logger.info("initializing inference method") if inference_method is None: @@ -153,7 +162,7 @@ class GP(Model): This method is not designed to be called manually, the framework is set up to automatically call this method upon changes to parameters, if you call this method yourself, there may be unexpected consequences. """ - self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y_normalized, self.Y_metadata) + self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y_normalized, self.mean_function, self.Y_metadata) self.likelihood.update_gradients(self.grad_dict['dL_dthetaL']) self.kern.update_gradients_full(self.grad_dict['dL_dK'], self.X)