diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index 16568c7d..77c42825 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -11,6 +11,7 @@ from base_plots import gpplot, x_frame1D, x_frame2D from ...models.gp_coregionalized_regression import GPCoregionalizedRegression from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression from scipy import sparse +from ...core.parameterization.variational import VariationalPosterior def plot_fit(model, plot_limits=None, which_data_rows='all', which_data_ycols='all', fixed_inputs=[], @@ -78,7 +79,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', if predict_kw is None: predict_kw = {} - + #work out what the inputs are for plotting (1D or 2D) fixed_dims = np.array([i for i,v in fixed_inputs]) free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims) @@ -219,7 +220,7 @@ def plot_fit_f(model, *args, **kwargs): kwargs['plot_raw'] = True plot_fit(model,*args, **kwargs) -def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True): +def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False): """ Convenience function for returning back fixed_inputs where the other inputs are fixed using fix_routine @@ -235,8 +236,13 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True): f_inputs = [] if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): X = model.X.mean.values.copy() - else: + elif isinstance(model.X, VariationalPosterior): X = model.X.values.copy() + else: + if X_all: + X = model.X_all.copy() + else: + X = model.X.copy() for i in range(X.shape[1]): if i not in non_fixed_inputs: if fix_routine == 'mean':