From 3aa563d5bad488d3e577ca3c457d2519e0255f21 Mon Sep 17 00:00:00 2001 From: beckdaniel Date: Tue, 4 Aug 2015 14:22:46 +0100 Subject: [PATCH] first try in implementing warped mean --- GPy/models/warped_gp.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/GPy/models/warped_gp.py b/GPy/models/warped_gp.py index eec37e48..540b6cb2 100644 --- a/GPy/models/warped_gp.py +++ b/GPy/models/warped_gp.py @@ -25,7 +25,7 @@ class WarpedGP(GP): Y = self._scale_data(Y) #self.has_uncertain_inputs = False self.Y_untransformed = Y.copy() - self.predict_in_warped_space = False + self.predict_in_warped_space = True likelihood = likelihoods.Gaussian() GP.__init__(self, X, self.transform_data(), likelihood=likelihood, kernel=kernel) @@ -69,7 +69,19 @@ class WarpedGP(GP): def plot_warping(self): self.warping_function.plot(self.Y_untransformed.min(), self.Y_untransformed.max()) - def predict(self, Xnew, which_parts='all', pred_init=None, full_cov=False, Y_metadata=None): + def _get_warped_mean(self, mean, var, pred_init=None, deg_gauss_hermite=100): + """ + Calculate the warped mean by using Gauss-Hermite quadrature. + """ + gh_samples, gh_weights = np.polynomial.hermite.hermgauss(deg_gauss_hermite) + gh_samples = gh_samples[:,None] + gh_weights = gh_weights[None,:] + arg1 = gh_samples.dot(var.T) * np.sqrt(2) + arg2 = np.ones(shape=gh_samples.shape).dot(mean.T) + return gh_weights.dot(self.warping_function.f_inv(arg1 + arg2, y=pred_init)) / np.sqrt(np.pi) + + def predict(self, Xnew, which_parts='all', pred_init=None, full_cov=False, Y_metadata=None, + median=False, deg_gauss_hermite=100): # normalize X values # Xnew = (Xnew.copy() - self._Xoffset) / self._Xscale mu, var = GP._raw_predict(self, Xnew) @@ -78,13 +90,17 @@ class WarpedGP(GP): mean, var = self.likelihood.predictive_values(mu, var) if self.predict_in_warped_space: - mean = self.warping_function.f_inv(mean, y=pred_init) + if median: + pred = self.warping_function.f_inv(mean, y=pred_init) + else: + pred = self._get_warped_mean(mean, var, pred_init=pred_init, + deg_gauss_hermite=deg_gauss_hermite).T var = self.warping_function.f_inv(var) if self.scale_data: - mean = self._unscale_data(mean) + pred = self._unscale_data(pred) - return mean, var + return pred, var def predict_quantiles(self, X, quantiles=(2.5, 97.5), Y_metadata=None): """