[classification] sparse gp classification and dtc update

This commit is contained in:
Max Zwiessele 2015-09-11 15:08:30 +01:00
parent 4ea5ebaa68
commit 1d354f5cce
14 changed files with 208 additions and 369 deletions

View file

@ -40,7 +40,7 @@ class SparseGP(GP):
"""
def __init__(self, X, Y, Z, kernel, likelihood, mean_function=None, inference_method=None,
def __init__(self, X, Y, Z, kernel, likelihood, mean_function=None, X_variance=None, inference_method=None,
name='sparse gp', Y_metadata=None, normalizer=False):
#pick a sensible inference method
if inference_method is None:
@ -73,11 +73,12 @@ class SparseGP(GP):
self.Z = Param('inducing inputs',Z)
self.link_parameter(self.Z, index=0)
if trigger_update: self.update_model(True)
if trigger_update: self._trigger_params_changed()
def parameters_changed(self):
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata)
self._update_gradients()
def _update_gradients(self):
self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
if isinstance(self.X, VariationalPosterior):