mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
Add mean function functionality to dtc inference method
This commit is contained in:
parent
754c67f71d
commit
a24a9b3edc
5 changed files with 45 additions and 18 deletions
|
|
@ -74,11 +74,16 @@ class SparseGP(GP):
|
|||
if trigger_update: self.update_model(True)
|
||||
|
||||
def parameters_changed(self):
|
||||
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata)
|
||||
self.posterior, self._log_marginal_likelihood, self.grad_dict = \
|
||||
self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood,
|
||||
self.Y, Y_metadata=self.Y_metadata,
|
||||
mean_function=self.mean_function)
|
||||
self._update_gradients()
|
||||
|
||||
def _update_gradients(self):
|
||||
self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
|
||||
if self.mean_function is not None:
|
||||
self.mean_function.update_gradients(self.grad_dict['dL_dm'], self.X)
|
||||
|
||||
if isinstance(self.X, VariationalPosterior):
|
||||
#gradients wrt kernel
|
||||
|
|
@ -112,4 +117,3 @@ class SparseGP(GP):
|
|||
self.Z.gradient = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z)
|
||||
self.Z.gradient += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X)
|
||||
self._Zgrad = self.Z.gradient.copy()
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ class SparseGP_MPI(SparseGP):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, X, Y, Z, kernel, likelihood, variational_prior=None, inference_method=None, name='sparse gp', Y_metadata=None, mpi_comm=None, normalizer=False):
|
||||
def __init__(self, X, Y, Z, kernel, likelihood, variational_prior=None,
|
||||
mean_function=None, inference_method=None, name='sparse gp',
|
||||
Y_metadata=None, mpi_comm=None, normalizer=False):
|
||||
self._IN_OPTIMIZATION_ = False
|
||||
if mpi_comm != None:
|
||||
if inference_method is None:
|
||||
|
|
@ -42,12 +44,12 @@ class SparseGP_MPI(SparseGP):
|
|||
else:
|
||||
assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!'
|
||||
|
||||
super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
|
||||
super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, mean_function=mean_function, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
|
||||
self.update_model(False)
|
||||
|
||||
|
||||
if variational_prior is not None:
|
||||
self.link_parameter(variational_prior)
|
||||
|
||||
|
||||
self.mpi_comm = mpi_comm
|
||||
# Manage the data (Y) division
|
||||
if mpi_comm != None:
|
||||
|
|
@ -118,4 +120,3 @@ class SparseGP_MPI(SparseGP):
|
|||
update_gradients(self, mpi_comm=self.mpi_comm)
|
||||
else:
|
||||
super(SparseGP_MPI,self).parameters_changed()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue