mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
changes to multiout constructor
This commit is contained in:
parent
a8eb7eb5f7
commit
074529c1c0
3 changed files with 23 additions and 11 deletions
|
|
@ -19,7 +19,9 @@ class WarpedGP(GP):
|
|||
self.warping_function = TanhWarpingFunction_d(warping_terms)
|
||||
self.warping_params = (np.random.randn(self.warping_function.n_terms * 3 + 1,) * 1)
|
||||
|
||||
Y = self._scale_data(Y)
|
||||
self.scale_data = False
|
||||
if self.scale_data:
|
||||
Y = self._scale_data(Y)
|
||||
self.has_uncertain_inputs = False
|
||||
self.Y_untransformed = Y.copy()
|
||||
self.predict_in_warped_space = False
|
||||
|
|
@ -87,11 +89,19 @@ class WarpedGP(GP):
|
|||
def plot_warping(self):
|
||||
self.warping_function.plot(self.warping_params, self.Y_untransformed.min(), self.Y_untransformed.max())
|
||||
|
||||
def _raw_predict(self, *args, **kwargs):
|
||||
mu, var = GP._raw_predict(self, *args, **kwargs)
|
||||
def predict(self, Xnew, which_parts='all', full_cov=False, pred_init=None):
|
||||
# normalize X values
|
||||
Xnew = (Xnew.copy() - self._Xoffset) / self._Xscale
|
||||
mu, var = GP._raw_predict(self, Xnew, full_cov=full_cov, which_parts=which_parts)
|
||||
|
||||
# now push through likelihood
|
||||
mean, var, _025pm, _975pm = self.likelihood.predictive_values(mu, var, full_cov)
|
||||
|
||||
if self.predict_in_warped_space:
|
||||
mu = self.warping_function.f_inv(mu, self.warping_params)
|
||||
mean = self.warping_function.f_inv(mean, self.warping_params, y=pred_init)
|
||||
var = self.warping_function.f_inv(var, self.warping_params)
|
||||
mu = self._unscale_data(mu)
|
||||
return mu, var
|
||||
|
||||
if self.scale_data:
|
||||
mean = self._unscale_data(mean)
|
||||
|
||||
return mean, var, _025pm, _975pm
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue