add trigger update to set_{X,Y,Z}

This commit is contained in:
Zhenwen Dai 2015-03-13 09:47:36 +00:00
parent c6cf0bc121
commit 7f30fef698
2 changed files with 12 additions and 12 deletions

View file

@ -89,7 +89,7 @@ class GP(Model):
self.link_parameter(self.kern) self.link_parameter(self.kern)
self.link_parameter(self.likelihood) self.link_parameter(self.likelihood)
def set_XY(self, X=None, Y=None): def set_XY(self, X=None, Y=None, trigger_update=True):
""" """
Set the input / output data of the model Set the input / output data of the model
This is useful if we wish to change our existing data but maintain the same model This is useful if we wish to change our existing data but maintain the same model
@ -99,7 +99,7 @@ class GP(Model):
:param Y: output observations :param Y: output observations
:type Y: np.ndarray :type Y: np.ndarray
""" """
self.update_model(False) if trigger_update: self.update_model(False)
if Y is not None: if Y is not None:
if self.normalizer is not None: if self.normalizer is not None:
self.normalizer.scale_by(Y) self.normalizer.scale_by(Y)
@ -123,26 +123,26 @@ class GP(Model):
self.link_parameters(self.X) self.link_parameters(self.X)
else: else:
self.X = ObsAr(X) self.X = ObsAr(X)
self.update_model(True) if trigger_update: self.update_model(True)
self._trigger_params_changed() if trigger_update: self._trigger_params_changed()
def set_X(self,X): def set_X(self,X, trigger_update=True):
""" """
Set the input data of the model Set the input data of the model
:param X: input observations :param X: input observations
:type X: np.ndarray :type X: np.ndarray
""" """
self.set_XY(X=X) self.set_XY(X=X, trigger_update=trigger_update)
def set_Y(self,Y): def set_Y(self,Y, trigger_update=True):
""" """
Set the output data of the model Set the output data of the model
:param X: output observations :param X: output observations
:type X: np.ndarray :type X: np.ndarray
""" """
self.set_XY(Y=Y) self.set_XY(Y=Y, trigger_update=trigger_update)
def parameters_changed(self): def parameters_changed(self):
""" """

View file

@ -63,14 +63,14 @@ class SparseGP(GP):
def has_uncertain_inputs(self): def has_uncertain_inputs(self):
return isinstance(self.X, VariationalPosterior) return isinstance(self.X, VariationalPosterior)
def set_Z(self, Z): def set_Z(self, Z, trigger_update=True):
self.update_model(False) if trigger_update: self.update_model(False)
self.unlink_parameter(self.Z) self.unlink_parameter(self.Z)
from ..core import Param from ..core import Param
self.Z = Param('inducing inputs',Z) self.Z = Param('inducing inputs',Z)
self.link_parameter(self.Z, index=0) self.link_parameter(self.Z, index=0)
self.update_model(True) if trigger_update: self.update_model(True)
self._trigger_params_changed() if trigger_update: self._trigger_params_changed()
def parameters_changed(self): def parameters_changed(self):
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata) self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata)