fixed plotting isue with plot_f

This commit is contained in:
James Hensman 2013-10-28 21:41:10 +00:00
parent bd062329a8
commit e5487bff19

View file

@ -99,13 +99,13 @@ class GPBase(Model):
see also: gp_base.plot
"""
kwargs['use_raw_predict'] = True
kwargs['plot_raw'] = True
self.plot(*args, **kwargs)
def plot(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', which_parts='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
use_raw_predict=False,
plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
"""
Plot the posterior of the GP.
@ -170,15 +170,17 @@ class GPBase(Model):
Xgrid[:,i] = v
#make a prediction on the frame and plot it
if use_raw_predict:
if plot_raw:
m, v = self._raw_predict(Xgrid, which_parts=which_parts)
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
Y = self.likelihood.Y
else:
m, v, lower, upper = self.predict(Xgrid, which_parts=which_parts)
Y = self.likelihood.data
for d in which_data_ycols:
gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], axes=ax, edgecol=linecol, fillcol=fillcol)
ax.plot(Xu[which_data_rows,free_dims], self.likelihood.data[which_data_rows, d], 'kx', mew=1.5)
ax.plot(Xu[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5)
#optionally plot some samples
if samples: #NOTE not tested with fixed_inputs
@ -209,13 +211,14 @@ class GPBase(Model):
#predict on the frame and plot
if use_raw_predict:
m, _ = self._raw_predict(Xgrid, which_parts=which_parts)
Y = self.likelihood.Y
else:
m, _, _, _ = self.predict(Xgrid, which_parts=which_parts)
Y = self.likelihood.data
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
Y_d = self.likelihood.Y[which_data_rows,d]
ax.scatter(self.X[which_data_rows, free_dims[0]], self.X[which_data_rows, free_dims[1]], 40, Y_d, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
ax.scatter(self.X[which_data_rows, free_dims[0]], self.X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
#set the limits of the plot to some sensible values
ax.set_xlim(xmin[0], xmax[0])