fix: samples tests and plotting, multioutput

This commit is contained in:
mzwiessele 2018-09-02 19:07:23 +01:00
parent 270b90857c
commit 8446da628b
45 changed files with 15 additions and 17 deletions

View file

@ -578,7 +578,7 @@ class GP(Model):
mag[n] = np.sqrt(np.linalg.det(G[n, :, :]))
return mag
def posterior_samples_f(self,X, size=10, **predict_kwargs):
def posterior_samples_f(self,X, size=10, **predict_kwargs):
"""
Samples the posterior GP at the points X.
@ -589,7 +589,8 @@ class GP(Model):
:returns: set of simulations
:rtype: np.ndarray (Nnew x D x samples)
"""
m, v = self._raw_predict(X, full_cov=True, **predict_kwargs)
predict_kwargs["full_cov"] = True # Always use the full covariance for posterior samples.
m, v = self._raw_predict(X, **predict_kwargs)
if self.normalizer is not None:
m, v = self.normalizer.inverse_mean(m), self.normalizer.inverse_variance(v)

View file

@ -208,10 +208,10 @@ def _plot_samples(self, canvas, helper_data, helper_prediction, projection,
if len(free_dims)==1:
# 1D plotting:
update_not_existing_kwargs(kwargs, pl().defaults.samples_1d) # @UndefinedVariable
plots = [pl().plot(canvas, Xgrid[:, free_dims], samples[:, s], label=label if s==0 else None, **kwargs) for s in range(samples.shape[-1])]
plots = [pl().plot(canvas, Xgrid[:, free_dims], samples[:, :, s], label=label if s==0 else None, **kwargs) for s in range(samples.shape[-1])]
elif len(free_dims)==2 and projection=='3d':
update_not_existing_kwargs(kwargs, pl().defaults.samples_3d) # @UndefinedVariable
plots = [pl().surface(canvas, x, y, samples[:, s].reshape(resolution, resolution), **kwargs) for s in range(samples.shape[-1])]
plots = [pl().surface(canvas, x, y, samples[:, :, s].reshape(resolution, resolution), **kwargs) for s in range(samples.shape[-1])]
else:
pass # Nothing to plot!
return dict(gpmean=plots)

View file

@ -81,8 +81,8 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
else: percentiles = []
if samples > 0:
fsamples = self.posterior_samples(Xgrid, full_cov=True, size=samples, **predict_kw)
fsamples = fsamples[which_data_ycols] if fsamples.ndim == 3 else fsamples
fsamples = self.posterior_samples(Xgrid, size=samples, **predict_kw)
fsamples = fsamples[:, which_data_ycols, :]
else:
fsamples = None
@ -95,12 +95,9 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
retmu[:, [i]] = self.likelihood.gp_link.transf(mu[:, [i]])
for perc in percs:
perc[:, [i]] = self.likelihood.gp_link.transf(perc[:, [i]])
if fsamples is not None and fsamples.ndim == 3:
if fsamples is not None:
for s in range(fsamples.shape[-1]):
fsamples[i, :, s] = self.likelihood.gp_link.transf(fsamples[i, :, s])
elif fsamples is not None:
for s in range(fsamples.shape[-1]):
fsamples[:, s] = self.likelihood.gp_link.transf(fsamples[:, s])
fsamples[:, i, s] = self.likelihood.gp_link.transf(fsamples[:, i, s])
return retmu, percs, fsamples
def helper_for_plot_data(self, X, plot_limits, visible_dims, fixed_inputs, resolution):

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -38,7 +38,7 @@ from nose import SkipTest
try:
import matplotlib
matplotlib.use('agg')
# matplotlib.use('agg')
except ImportError:
# matplotlib not installed
from nose import SkipTest
@ -367,13 +367,13 @@ def test_classification():
m = GPy.models.GPClassification(X, Y>Y.mean())
#m.optimize()
_, ax = plt.subplots()
m.plot(plot_raw=False, apply_link=False, ax=ax)
m.plot(plot_raw=False, apply_link=False, ax=ax, samples=3)
m.plot_errorbars_trainset(plot_raw=False, apply_link=False, ax=ax)
_, ax = plt.subplots()
m.plot(plot_raw=True, apply_link=False, ax=ax)
m.plot(plot_raw=True, apply_link=False, ax=ax, samples=3)
m.plot_errorbars_trainset(plot_raw=True, apply_link=False, ax=ax)
_, ax = plt.subplots()
m.plot(plot_raw=True, apply_link=True, ax=ax)
m.plot(plot_raw=True, apply_link=True, ax=ax, samples=3)
m.plot_errorbars_trainset(plot_raw=True, apply_link=True, ax=ax)
for do_test in _image_comparison(baseline_images=['gp_class_{}'.format(sub) for sub in ["likelihood", "raw", 'raw_link']], extensions=extensions):
yield (do_test, )