mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
[plotting] added samples plot
This commit is contained in:
parent
57c4306d92
commit
d7b5f45b71
7 changed files with 118 additions and 31 deletions
|
|
@ -31,7 +31,7 @@
|
|||
import numpy as np
|
||||
from scipy import sparse
|
||||
|
||||
def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, which_data_ycols, predict_kw):
|
||||
def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, which_data_ycols, predict_kw, samples=0):
|
||||
"""
|
||||
Make the right decisions for prediction with a model
|
||||
based on the standard arguments of plotting.
|
||||
|
|
@ -46,12 +46,12 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
|
|||
if plot_raw:
|
||||
from ...likelihoods import Gaussian
|
||||
from ...likelihoods.link_functions import Identity
|
||||
lik = Gaussian(Identity(), 0) # Make the likelihood not add any noise
|
||||
lik = Gaussian(Identity(), 1e-9) # Make the likelihood not add any noise
|
||||
else:
|
||||
lik = None
|
||||
predict_kw['likelihood'] = lik
|
||||
if 'Y_metadata' not in predict_kw:
|
||||
predict_kw['Y_metadata'] = self.Y_metadata or {}
|
||||
predict_kw['Y_metadata'] = {}
|
||||
if 'output_index' not in predict_kw['Y_metadata']:
|
||||
predict_kw['Y_metadata']['output_index'] = Xgrid[:,-1:].astype(np.int)
|
||||
|
||||
|
|
@ -61,6 +61,12 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
|
|||
percentiles = self.predict_quantiles(Xgrid, quantiles=percentiles, **predict_kw)
|
||||
else: percentiles = []
|
||||
|
||||
if samples > 0:
|
||||
fsamples = self.posterior_samples(Xgrid, full_cov=True, size=samples, **predict_kw)
|
||||
fsamples = fsamples[which_data_ycols] if fsamples.ndim == 3 else fsamples
|
||||
else:
|
||||
fsamples = None
|
||||
|
||||
# Filter out the ycolums which we want to plot:
|
||||
retmu = mu[:, which_data_ycols]
|
||||
percs = [p[:, which_data_ycols] for p in percentiles]
|
||||
|
|
@ -70,8 +76,13 @@ def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, wh
|
|||
retmu[:, [i]] = self.likelihood.gp_link.transf(mu[:, [i]])
|
||||
for perc in percs:
|
||||
perc[:, [i]] = self.likelihood.gp_link.transf(perc[:, [i]])
|
||||
|
||||
return retmu, percs
|
||||
if fsamples is not None and fsamples.ndim == 3:
|
||||
for s in range(fsamples.shape[-1]):
|
||||
fsamples[i, :, s] = self.likelihood.gp_link.transf(fsamples[i, :, s])
|
||||
elif fsamples is not None:
|
||||
for s in range(fsamples.shape[-1]):
|
||||
fsamples[:, s] = self.likelihood.gp_link.transf(fsamples[:, s])
|
||||
return retmu, percs, fsamples
|
||||
|
||||
def helper_for_plot_data(self, plot_limits, fixed_inputs, resolution):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue