mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
Plotting functions modified
This commit is contained in:
parent
73a7b6f079
commit
fd34d8aa05
2 changed files with 13 additions and 18 deletions
|
|
@ -1,7 +1,6 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
import pylab as pb
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
from scipy import optimize
|
from scipy import optimize
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
@ -57,13 +56,13 @@ class Optimizer():
|
||||||
raise NotImplementedError, "this needs to be implemented to use the optimizer class"
|
raise NotImplementedError, "this needs to be implemented to use the optimizer class"
|
||||||
|
|
||||||
def plot(self):
|
def plot(self):
|
||||||
if self.trace == None:
|
"""
|
||||||
print "No trace present so I can't plot it. Please check that the optimizer actually supplies a trace."
|
See GPy.plotting.matplot_dep.inference_plots
|
||||||
else:
|
"""
|
||||||
pb.figure()
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
pb.plot(self.trace)
|
from ..plotting.matplot_dep import inference_plots
|
||||||
pb.xlabel('Iteration')
|
inference_plots.plot_optimizer(self)
|
||||||
pb.ylabel('f(x)')
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
diagnostics = "Optimizer: \t\t\t\t %s\n" % self.opt_name
|
diagnostics = "Optimizer: \t\t\t\t %s\n" % self.opt_name
|
||||||
|
|
|
||||||
|
|
@ -68,16 +68,12 @@ class opt_SGD(Optimizer):
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def plot_traces(self):
|
def plot_traces(self):
|
||||||
plt.figure()
|
"""
|
||||||
plt.subplot(211)
|
See GPy.plotting.matplot_dep.inference_plots
|
||||||
plt.title('Parameters')
|
"""
|
||||||
for k in self.param_traces.keys():
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
plt.plot(self.param_traces[k], label=k)
|
from ..plotting.matplot_dep import inference_plots
|
||||||
plt.legend(loc=0)
|
inference_plots.plot_sgd_traces(self)
|
||||||
plt.subplot(212)
|
|
||||||
plt.title('Objective function')
|
|
||||||
plt.plot(self.fopt_trace)
|
|
||||||
|
|
||||||
|
|
||||||
def non_null_samples(self, data):
|
def non_null_samples(self, data):
|
||||||
return (np.isnan(data).sum(axis=1) == 0)
|
return (np.isnan(data).sum(axis=1) == 0)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue