mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 03:22:38 +02:00
[plotting] added predict_kw to plot function
This commit is contained in:
parent
9c19f8584e
commit
335df2942f
4 changed files with 22 additions and 12 deletions
|
|
@ -395,7 +395,7 @@ class GP(Model):
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
plot_raw=False,
|
plot_raw=False,
|
||||||
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
|
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx', predict_kw=None):
|
||||||
"""
|
"""
|
||||||
Plot the posterior of the GP.
|
Plot the posterior of the GP.
|
||||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
||||||
|
|
@ -444,7 +444,7 @@ class GP(Model):
|
||||||
which_data_ycols, fixed_inputs,
|
which_data_ycols, fixed_inputs,
|
||||||
levels, samples, fignum, ax, resolution,
|
levels, samples, fignum, ax, resolution,
|
||||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
||||||
data_symbol=data_symbol, **kw)
|
data_symbol=data_symbol, predict_kw=predict_kw, **kw)
|
||||||
|
|
||||||
def input_sensitivity(self, summarize=True):
|
def input_sensitivity(self, summarize=True):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ class SparseGP(GP):
|
||||||
else:
|
else:
|
||||||
Kxx = kern.Kdiag(Xnew)
|
Kxx = kern.Kdiag(Xnew)
|
||||||
if self.posterior.woodbury_inv.ndim == 2:
|
if self.posterior.woodbury_inv.ndim == 2:
|
||||||
var = Kxx - np.sum(np.dot(self.posterior.woodbury_inv.T, Kx) * Kx, 0)
|
var = (Kxx - np.sum(np.dot(self.posterior.woodbury_inv.T, Kx) * Kx, 0))[:,None]
|
||||||
elif self.posterior.woodbury_inv.ndim == 3:
|
elif self.posterior.woodbury_inv.ndim == 3:
|
||||||
var = np.empty((Kxx.shape[0],self.posterior.woodbury_inv.shape[2]))
|
var = np.empty((Kxx.shape[0],self.posterior.woodbury_inv.shape[2]))
|
||||||
for i in range(var.shape[1]):
|
for i in range(var.shape[1]):
|
||||||
|
|
@ -147,9 +147,9 @@ class SparseGP(GP):
|
||||||
if self.mean_function is not None:
|
if self.mean_function is not None:
|
||||||
mu += self.mean_function.f(Xnew)
|
mu += self.mean_function.f(Xnew)
|
||||||
else:
|
else:
|
||||||
psi0_star = self.kern.psi0(self.Z, Xnew)
|
psi0_star = kern.psi0(self.Z, Xnew)
|
||||||
psi1_star = self.kern.psi1(self.Z, Xnew)
|
psi1_star = kern.psi1(self.Z, Xnew)
|
||||||
#psi2_star = self.kern.psi2(self.Z, Xnew) # Only possible if we get NxMxM psi2 out of the code.
|
#psi2_star = kern.psi2(self.Z, Xnew) # Only possible if we get NxMxM psi2 out of the code.
|
||||||
la = self.posterior.woodbury_vector
|
la = self.posterior.woodbury_vector
|
||||||
mu = np.dot(psi1_star, la) # TODO: dimensions?
|
mu = np.dot(psi1_star, la) # TODO: dimensions?
|
||||||
|
|
||||||
|
|
@ -161,7 +161,7 @@ class SparseGP(GP):
|
||||||
|
|
||||||
for i in range(Xnew.shape[0]):
|
for i in range(Xnew.shape[0]):
|
||||||
_mu, _var = Xnew.mean.values[[i]], Xnew.variance.values[[i]]
|
_mu, _var = Xnew.mean.values[[i]], Xnew.variance.values[[i]]
|
||||||
psi2_star = self.kern.psi2(self.Z, NormalPosterior(_mu, _var))
|
psi2_star = kern.psi2(self.Z, NormalPosterior(_mu, _var))
|
||||||
tmp = (psi2_star[:, :] - psi1_star[[i]].T.dot(psi1_star[[i]]))
|
tmp = (psi2_star[:, :] - psi1_star[[i]].T.dot(psi1_star[[i]]))
|
||||||
|
|
||||||
var_ = mdot(la.T, tmp, la)
|
var_ = mdot(la.T, tmp, la)
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,13 @@ class VerboseOptimization(object):
|
||||||
|
|
||||||
def finish(self, opt):
|
def finish(self, opt):
|
||||||
self.status = opt.status
|
self.status = opt.status
|
||||||
|
if self.verbose and self.ipython_notebook:
|
||||||
|
if 'conv' in self.status.lower():
|
||||||
|
self.progress.bar_style = 'success'
|
||||||
|
elif self.iteration >= self.maxiters:
|
||||||
|
self.progress.bar_style = 'warning'
|
||||||
|
else:
|
||||||
|
self.progress.bar_style = 'danger'
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
plot_raw=False,
|
plot_raw=False,
|
||||||
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
|
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
|
||||||
apply_link=False, samples_f=0, plot_uncertain_inputs=True):
|
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None):
|
||||||
"""
|
"""
|
||||||
Plot the posterior of the GP.
|
Plot the posterior of the GP.
|
||||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
||||||
|
|
@ -76,6 +76,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
if hasattr(model, 'Z'): Z = model.Z
|
if hasattr(model, 'Z'): Z = model.Z
|
||||||
|
|
||||||
|
if predict_kw is None:
|
||||||
|
predict_kw = {}
|
||||||
|
|
||||||
#work out what the inputs are for plotting (1D or 2D)
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
fixed_dims = np.array([i for i,v in fixed_inputs])
|
fixed_dims = np.array([i for i,v in fixed_inputs])
|
||||||
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
||||||
|
|
@ -92,7 +95,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
#make a prediction on the frame and plot it
|
#make a prediction on the frame and plot it
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
m, v = model._raw_predict(Xgrid)
|
m, v = model._raw_predict(Xgrid, **predict_kw)
|
||||||
if apply_link:
|
if apply_link:
|
||||||
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
|
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
|
||||||
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
|
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
|
||||||
|
|
@ -106,7 +109,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
|
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
|
||||||
else:
|
else:
|
||||||
meta = None
|
meta = None
|
||||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta)
|
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta, **predict_kw)
|
||||||
lower, upper = model.predict_quantiles(Xgrid, Y_metadata=meta)
|
lower, upper = model.predict_quantiles(Xgrid, Y_metadata=meta)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -178,13 +181,13 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
#predict on the frame and plot
|
#predict on the frame and plot
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
m, _ = model._raw_predict(Xgrid)
|
m, _ = model._raw_predict(Xgrid, **predict_kw)
|
||||||
else:
|
else:
|
||||||
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
|
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
|
||||||
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
|
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
|
||||||
else:
|
else:
|
||||||
meta = None
|
meta = None
|
||||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta)
|
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta, **predict_kw)
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
m_d = m[:,d].reshape(resolution, resolution).T
|
m_d = m[:,d].reshape(resolution, resolution).T
|
||||||
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
|
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue