changes to multiout constructor

This commit is contained in:
Nicolo Fusi 2013-07-08 14:27:32 +01:00
parent a8eb7eb5f7
commit 074529c1c0
3 changed files with 23 additions and 11 deletions

View file

@ -19,7 +19,9 @@ class WarpedGP(GP):
self.warping_function = TanhWarpingFunction_d(warping_terms)
self.warping_params = (np.random.randn(self.warping_function.n_terms * 3 + 1,) * 1)
Y = self._scale_data(Y)
self.scale_data = False
if self.scale_data:
Y = self._scale_data(Y)
self.has_uncertain_inputs = False
self.Y_untransformed = Y.copy()
self.predict_in_warped_space = False
@ -87,11 +89,19 @@ class WarpedGP(GP):
def plot_warping(self):
self.warping_function.plot(self.warping_params, self.Y_untransformed.min(), self.Y_untransformed.max())
def _raw_predict(self, *args, **kwargs):
mu, var = GP._raw_predict(self, *args, **kwargs)
def predict(self, Xnew, which_parts='all', full_cov=False, pred_init=None):
# normalize X values
Xnew = (Xnew.copy() - self._Xoffset) / self._Xscale
mu, var = GP._raw_predict(self, Xnew, full_cov=full_cov, which_parts=which_parts)
# now push through likelihood
mean, var, _025pm, _975pm = self.likelihood.predictive_values(mu, var, full_cov)
if self.predict_in_warped_space:
mu = self.warping_function.f_inv(mu, self.warping_params)
mean = self.warping_function.f_inv(mean, self.warping_params, y=pred_init)
var = self.warping_function.f_inv(var, self.warping_params)
mu = self._unscale_data(mu)
return mu, var
if self.scale_data:
mean = self._unscale_data(mean)
return mean, var, _025pm, _975pm