Plotting functions modified

This commit is contained in:
Ricardo 2014-01-28 13:40:48 +00:00
parent 67658945b1
commit 73a7b6f079

View file

@ -2,7 +2,6 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np import numpy as np
import pylab as pb
from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs, jitchol, backsub_both_sides from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs, jitchol, backsub_both_sides
from gp import GP from gp import GP
import time import time
@ -480,38 +479,19 @@ class SVIGP(GP):
return self.q_u_canonical_flat return self.q_u_canonical_flat
def plot(self, ax=None, fignum=None, Z_height=None, **kwargs): def plot(self, *args, **kwargs):
"""
See GPy.plotting.matplot_dep.svgi_plots.plot
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import svgi_plots
svgi_plots.plot(self,*args,**kwargs)
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
#horrible hack here:
data = self.likelihood.data.copy()
self.likelihood.data = self.Y
GP.plot(self, ax=ax, **kwargs)
self.likelihood.data = data
Zu = self.Z * self._Xscale + self._Xoffset
if self.input_dim==1:
ax.plot(self.X_batch, self.likelihood.data, 'gx',mew=2)
if Z_height is None:
Z_height = ax.get_ylim()[0]
ax.plot(Zu, np.zeros_like(Zu) + Z_height, 'r|', mew=1.5, markersize=12)
if self.input_dim==2:
ax.scatter(self.X[:,0], self.X[:,1], 20., self.Y[:,0], linewidth=0, cmap=pb.cm.jet) # @UndefinedVariable
ax.plot(Zu[:,0], Zu[:,1], 'w^')
def plot_traces(self): def plot_traces(self):
pb.figure() """
t = np.array(self._param_trace) See GPy.plotting.matplot_dep.svgi_plots.plot_traces
pb.subplot(2,1,1) """
for l,ti in zip(self._get_param_names(),t.T): assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
if not l[:3]=='iip': from ..plotting.matplot_dep import svgi_plots
pb.plot(ti,label=l) svgi_plots.plot_traces(self)
pb.legend(loc=0)
pb.subplot(2,1,2)
pb.plot(np.asarray(self._ll_trace),label='stochastic likelihood')
pb.legend(loc=0)