first try in implementing warped mean

This commit is contained in:
beckdaniel 2015-08-04 14:22:46 +01:00
parent cf4c5af487
commit 3aa563d5ba

View file

@ -25,7 +25,7 @@ class WarpedGP(GP):
Y = self._scale_data(Y)
#self.has_uncertain_inputs = False
self.Y_untransformed = Y.copy()
self.predict_in_warped_space = False
self.predict_in_warped_space = True
likelihood = likelihoods.Gaussian()
GP.__init__(self, X, self.transform_data(), likelihood=likelihood, kernel=kernel)
@ -69,7 +69,19 @@ class WarpedGP(GP):
def plot_warping(self):
self.warping_function.plot(self.Y_untransformed.min(), self.Y_untransformed.max())
def predict(self, Xnew, which_parts='all', pred_init=None, full_cov=False, Y_metadata=None):
def _get_warped_mean(self, mean, var, pred_init=None, deg_gauss_hermite=100):
"""
Calculate the warped mean by using Gauss-Hermite quadrature.
"""
gh_samples, gh_weights = np.polynomial.hermite.hermgauss(deg_gauss_hermite)
gh_samples = gh_samples[:,None]
gh_weights = gh_weights[None,:]
arg1 = gh_samples.dot(var.T) * np.sqrt(2)
arg2 = np.ones(shape=gh_samples.shape).dot(mean.T)
return gh_weights.dot(self.warping_function.f_inv(arg1 + arg2, y=pred_init)) / np.sqrt(np.pi)
def predict(self, Xnew, which_parts='all', pred_init=None, full_cov=False, Y_metadata=None,
median=False, deg_gauss_hermite=100):
# normalize X values
# Xnew = (Xnew.copy() - self._Xoffset) / self._Xscale
mu, var = GP._raw_predict(self, Xnew)
@ -78,13 +90,17 @@ class WarpedGP(GP):
mean, var = self.likelihood.predictive_values(mu, var)
if self.predict_in_warped_space:
mean = self.warping_function.f_inv(mean, y=pred_init)
if median:
pred = self.warping_function.f_inv(mean, y=pred_init)
else:
pred = self._get_warped_mean(mean, var, pred_init=pred_init,
deg_gauss_hermite=deg_gauss_hermite).T
var = self.warping_function.f_inv(var)
if self.scale_data:
mean = self._unscale_data(mean)
pred = self._unscale_data(pred)
return mean, var
return pred, var
def predict_quantiles(self, X, quantiles=(2.5, 97.5), Y_metadata=None):
"""