lots of fixes, including prediction being mean and variance only

This commit is contained in:
James Hensman 2014-03-13 14:42:03 +00:00
parent 365b8ae1e1
commit cc96f5b3d5
13 changed files with 118 additions and 128 deletions

View file

@ -12,7 +12,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
@ -84,17 +84,12 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
m, v = model._raw_predict(Xgrid)
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
Y = Y
else:
if 'noise_index' in model.Y_metadata.keys():
if np.unique(model.Y_metadata['noise_index'][which_data_rows]).size > 1:
print "Data slices choosen have different noise models. Just one will be used."
noise_index = np.repeat(model.Y_metadata['noise_index'][which_data_rows][0], Xgrid.shape[0])[:,None]
m, v, lower, upper = model.predict(Xgrid,full_cov=False,noise_index=noise_index)
else:
noise_index = None
m, v, lower, upper = model.predict(Xgrid,full_cov=False)
Y = Y
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata)
lower, upper = model.predict_quantiles(Xgrid, Y_metadata=Y_metadata)
for d in which_data_ycols:
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5)
@ -144,10 +139,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#predict on the frame and plot
if plot_raw:
m, _ = model._raw_predict(Xgrid)
Y = Y
else:
m, _, _, _ = model.predict(Xgrid)
Y = Y
m, _ = model.predict(Xgrid)
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)