diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index 7cc5c99a..36079f2e 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -141,7 +141,7 @@ class NormalPosterior(VariationalPosterior): holds the means and variances for a factorizing multivariate normal distribution ''' - def plot(self, *args): + def plot(self, *args, **kwargs): """ Plot latent space X in 1D: @@ -150,8 +150,7 @@ class NormalPosterior(VariationalPosterior): import sys assert "matplotlib" in sys.modules, "matplotlib package has not been imported." from ...plotting.matplot_dep import variational_plots - import matplotlib - return variational_plots.plot(self,*args) + return variational_plots.plot(self, *args, **kwargs) class SpikeAndSlabPosterior(VariationalPosterior): '''