coregionalized 2D plotting fixed

This commit is contained in:
Ricardo 2014-08-06 14:44:56 +01:00
parent 369cc0ba2b
commit 128e894560

View file

@ -150,7 +150,11 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if plot_raw:
m, _ = model._raw_predict(Xgrid)
else:
m, _ = model.predict(Xgrid)
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
else:
meta = None
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta)
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)