From 6e76a96d7724a831af89edbbefd17e3f770a74e3 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Sep 2015 17:30:16 +0100 Subject: [PATCH] errorbars fixed --- GPy/core/gp.py | 10 ++++++---- GPy/plotting/matplot_dep/models_plots.py | 7 +++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/GPy/core/gp.py b/GPy/core/gp.py index d748b085..ba72c912 100644 --- a/GPy/core/gp.py +++ b/GPy/core/gp.py @@ -565,7 +565,7 @@ class GP(Model): which_data_ycols, fixed_inputs, levels, samples, fignum, ax, resolution, plot_raw=plot_raw, Y_metadata=Y_metadata, - data_symbol=data_symbol, predict_kw=predict_kw, + data_symbol=data_symbol, predict_kw=predict_kw, plot_training_data=plot_training_data, **kw) @@ -613,9 +613,9 @@ class GP(Model): fignum, ax, data_symbol, **kw) - def plot_fit_errorbars(self, which_data_rows='all', + def errorbars_trainset(self, which_data_rows='all', which_data_ycols='all', fixed_inputs=[], fignum=None, ax=None, - linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True): + linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True,lw=None): """ Plot the posterior error bars corresponding to the training data @@ -640,7 +640,9 @@ class GP(Model): assert "matplotlib" in sys.modules, "matplotlib package has not been imported." from ..plotting.matplot_dep import models_plots kw = {} - return models_plots.plot_fit_errorbars(self, which_data_rows, which_data_ycols, fixed_inputs, + if lw is not None: + kw['lw'] = lw + return models_plots.errorbars_trainset(self, which_data_rows, which_data_ycols, fixed_inputs, fignum, ax, linecol, data_symbol, predict_kw, plot_training_data, **kw) diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index e1ba327e..87ffd740 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -330,11 +330,11 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_ return X -def plot_fit_errorbars(model, which_data_rows='all', +def errorbars_trainset(model, which_data_rows='all', which_data_ycols='all', fixed_inputs=[], fignum=None, ax=None, linecol='red', data_symbol='kx', - predict_kw=None, plot_training_data=True): + predict_kw=None, plot_training_data=True, **kwargs): """ Plot the posterior error bars corresponding to the training data @@ -386,9 +386,8 @@ def plot_fit_errorbars(model, which_data_rows='all', fmu, fv = model._raw_predict(X, full_cov=False, **predict_kw) lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=model.Y_metadata) - for d in which_data_ycols: - plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum ) + plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs ) if plot_training_data: plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows, visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)