diff --git a/GPy/plotting/gpy_plot/gp_plots.py b/GPy/plotting/gpy_plot/gp_plots.py index 702aeb7b..1283b6a7 100644 --- a/GPy/plotting/gpy_plot/gp_plots.py +++ b/GPy/plotting/gpy_plot/gp_plots.py @@ -81,6 +81,7 @@ def _plot_mean(self, canvas, helper_data, helper_prediction, **kwargs): _, free_dims, Xgrid, x, y, _, _, resolution = helper_data + plots = dict() if len(free_dims)<=2: mu, _, _ = helper_prediction if len(free_dims)==1: @@ -88,12 +89,12 @@ def _plot_mean(self, canvas, helper_data, helper_prediction, update_not_existing_kwargs(kwargs, pl().defaults.meanplot_1d) # @UndefinedVariable plots = dict(gpmean=[pl().plot(canvas, Xgrid[:, free_dims], mu, label=label, **kwargs)]) else: - if projection == '2d': + if projection.lower() in '2d': update_not_existing_kwargs(kwargs, pl().defaults.meanplot_2d) # @UndefinedVariable plots = dict(gpmean=[pl().contour(canvas, x[:,0], y[0,:], mu.reshape(resolution, resolution).T, levels=levels, label=label, **kwargs)]) - elif projection == '3d': + elif projection.lower() in '3d': update_not_existing_kwargs(kwargs, pl().defaults.meanplot_3d) # @UndefinedVariable plots = dict(gpmean=[pl().surface(canvas, x, y, mu.reshape(resolution, resolution),