[plotting] added predict_kw to plot function

This commit is contained in:
mzwiessele 2015-04-24 11:02:01 +02:00
parent 9c19f8584e
commit 335df2942f
4 changed files with 22 additions and 12 deletions

View file

@ -17,7 +17,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
apply_link=False, samples_f=0, plot_uncertain_inputs=True):
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
@ -76,6 +76,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if hasattr(model, 'Z'): Z = model.Z
if predict_kw is None:
predict_kw = {}
#work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
@ -92,7 +95,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#make a prediction on the frame and plot it
if plot_raw:
m, v = model._raw_predict(Xgrid)
m, v = model._raw_predict(Xgrid, **predict_kw)
if apply_link:
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
@ -106,7 +109,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
else:
meta = None
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta)
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta, **predict_kw)
lower, upper = model.predict_quantiles(Xgrid, Y_metadata=meta)
@ -178,13 +181,13 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#predict on the frame and plot
if plot_raw:
m, _ = model._raw_predict(Xgrid)
m, _ = model._raw_predict(Xgrid, **predict_kw)
else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
else:
meta = None
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta)
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta, **predict_kw)
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)