Merge branch 'wgps_improvements' of https://github.com/beckdaniel/GPy into beckdaniel-wgps_improvements

This commit is contained in:
mzwiessele 2016-08-03 13:27:52 +01:00
commit 9be51df4f1
5 changed files with 261 additions and 242 deletions

View file

@ -31,6 +31,7 @@
import numpy as np
from scipy import sparse
import itertools
from ...models import WarpedGP
def in_ipynb():
try:
@ -73,6 +74,9 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
if 'output_index' not in predict_kw['Y_metadata']:
predict_kw['Y_metadata']['output_index'] = Xgrid[:,-1:].astype(np.int)
if isinstance(self, WarpedGP) and self.predict_in_warped_space:
predict_kw['median'] = True
mu, _ = self.predict(Xgrid, **predict_kw)
if percentiles is not None:
@ -295,6 +299,8 @@ def get_x_y_var(model):
Y = model.Y.values
except AttributeError:
Y = model.Y
if isinstance(model, WarpedGP) and model.predict_in_warped_space:
Y = model.Y_untransformed
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
return X, X_variance, Y
@ -381,4 +387,4 @@ def x_frame2D(X,plot_limits=None,resolution=None):
resolution = resolution or 50
xx, yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.c_[xx.flat, yy.flat]
return Xnew, xx, yy, xmin, xmax
return Xnew, xx, yy, xmin, xmax