Plotting functions modified

This commit is contained in:
Ricardo 2014-01-28 13:41:16 +00:00
parent 73a7b6f079
commit fd34d8aa05
2 changed files with 13 additions and 18 deletions

View file

@ -1,7 +1,6 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
import pylab as pb
import datetime as dt import datetime as dt
from scipy import optimize from scipy import optimize
from warnings import warn from warnings import warn
@ -57,13 +56,13 @@ class Optimizer():
raise NotImplementedError, "this needs to be implemented to use the optimizer class" raise NotImplementedError, "this needs to be implemented to use the optimizer class"
def plot(self): def plot(self):
if self.trace == None: """
print "No trace present so I can't plot it. Please check that the optimizer actually supplies a trace." See GPy.plotting.matplot_dep.inference_plots
else: """
pb.figure() assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
pb.plot(self.trace) from ..plotting.matplot_dep import inference_plots
pb.xlabel('Iteration') inference_plots.plot_optimizer(self)
pb.ylabel('f(x)')
def __str__(self): def __str__(self):
diagnostics = "Optimizer: \t\t\t\t %s\n" % self.opt_name diagnostics = "Optimizer: \t\t\t\t %s\n" % self.opt_name

View file

@ -68,16 +68,12 @@ class opt_SGD(Optimizer):
return status return status
def plot_traces(self): def plot_traces(self):
plt.figure() """
plt.subplot(211) See GPy.plotting.matplot_dep.inference_plots
plt.title('Parameters') """
for k in self.param_traces.keys(): assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
plt.plot(self.param_traces[k], label=k) from ..plotting.matplot_dep import inference_plots
plt.legend(loc=0) inference_plots.plot_sgd_traces(self)
plt.subplot(212)
plt.title('Objective function')
plt.plot(self.fopt_trace)
def non_null_samples(self, data): def non_null_samples(self, data):
return (np.isnan(data).sum(axis=1) == 0) return (np.isnan(data).sum(axis=1) == 0)