added parseing of mean func to gp.py

This commit is contained in:
James Hensman 2015-03-23 14:56:19 +00:00
parent fa801bf46c
commit 2d312099c0

View file

@ -34,7 +34,7 @@ class GP(Model):
"""
def __init__(self, X, Y, kernel, likelihood, inference_method=None, name='gp', Y_metadata=None, normalizer=False):
def __init__(self, X, Y, kernel, likelihood, mean_function=None, inference_method=None, name='gp', Y_metadata=None, normalizer=False):
super(GP, self).__init__(name)
assert X.ndim == 2
@ -75,6 +75,15 @@ class GP(Model):
assert isinstance(likelihood, likelihoods.Likelihood)
self.likelihood = likelihood
#handle the mean function
self.mean_function = mean_function
if mean_function is not None:
assert isinstance(self.mean_function, Mapping)
assert mean_function.input_dim == self.input_dim
assert mean_function.output_dim == self.output_dim
self.add_parameter(mean_function)
#find a sensible inference method
logger.info("initializing inference method")
if inference_method is None:
@ -153,7 +162,7 @@ class GP(Model):
This method is not designed to be called manually, the framework is set up to automatically call this method upon changes to parameters, if you call
this method yourself, there may be unexpected consequences.
"""
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y_normalized, self.Y_metadata)
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.likelihood, self.Y_normalized, self.mean_function, self.Y_metadata)
self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
self.kern.update_gradients_full(self.grad_dict['dL_dK'], self.X)