From fd34d8aa055fb11d829e829cebf89d89ebf42b5a Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 28 Jan 2014 13:41:16 +0000 Subject: [PATCH] Plotting functions modified --- GPy/inference/optimization/optimization.py | 15 +++++++-------- GPy/inference/optimization/sgd.py | 16 ++++++---------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/GPy/inference/optimization/optimization.py b/GPy/inference/optimization/optimization.py index e65b862e..d9be46ce 100644 --- a/GPy/inference/optimization/optimization.py +++ b/GPy/inference/optimization/optimization.py @@ -1,7 +1,6 @@ # Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Licensed under the BSD 3-clause license (see LICENSE.txt) -import pylab as pb import datetime as dt from scipy import optimize from warnings import warn @@ -57,13 +56,13 @@ class Optimizer(): raise NotImplementedError, "this needs to be implemented to use the optimizer class" def plot(self): - if self.trace == None: - print "No trace present so I can't plot it. Please check that the optimizer actually supplies a trace." - else: - pb.figure() - pb.plot(self.trace) - pb.xlabel('Iteration') - pb.ylabel('f(x)') + """ + See GPy.plotting.matplot_dep.inference_plots + """ + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import inference_plots + inference_plots.plot_optimizer(self) + def __str__(self): diagnostics = "Optimizer: \t\t\t\t %s\n" % self.opt_name diff --git a/GPy/inference/optimization/sgd.py b/GPy/inference/optimization/sgd.py index 5cd144e8..3f14dc4b 100644 --- a/GPy/inference/optimization/sgd.py +++ b/GPy/inference/optimization/sgd.py @@ -68,16 +68,12 @@ class opt_SGD(Optimizer): return status def plot_traces(self): - plt.figure() - plt.subplot(211) - plt.title('Parameters') - for k in self.param_traces.keys(): - plt.plot(self.param_traces[k], label=k) - plt.legend(loc=0) - plt.subplot(212) - plt.title('Objective function') - plt.plot(self.fopt_trace) - + """ + See GPy.plotting.matplot_dep.inference_plots + """ + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import inference_plots + inference_plots.plot_sgd_traces(self) def non_null_samples(self, data): return (np.isnan(data).sum(axis=1) == 0)