mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
Plotting functions modified
This commit is contained in:
parent
67658945b1
commit
73a7b6f079
1 changed files with 13 additions and 33 deletions
|
|
@ -2,7 +2,6 @@
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pylab as pb
|
|
||||||
from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs, jitchol, backsub_both_sides
|
from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs, jitchol, backsub_both_sides
|
||||||
from gp import GP
|
from gp import GP
|
||||||
import time
|
import time
|
||||||
|
|
@ -480,38 +479,19 @@ class SVIGP(GP):
|
||||||
return self.q_u_canonical_flat
|
return self.q_u_canonical_flat
|
||||||
|
|
||||||
|
|
||||||
def plot(self, ax=None, fignum=None, Z_height=None, **kwargs):
|
def plot(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
See GPy.plotting.matplot_dep.svgi_plots.plot
|
||||||
|
"""
|
||||||
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
|
from ..plotting.matplot_dep import svgi_plots
|
||||||
|
svgi_plots.plot(self,*args,**kwargs)
|
||||||
|
|
||||||
if ax is None:
|
|
||||||
fig = pb.figure(num=fignum)
|
|
||||||
ax = fig.add_subplot(111)
|
|
||||||
|
|
||||||
#horrible hack here:
|
|
||||||
data = self.likelihood.data.copy()
|
|
||||||
self.likelihood.data = self.Y
|
|
||||||
GP.plot(self, ax=ax, **kwargs)
|
|
||||||
self.likelihood.data = data
|
|
||||||
|
|
||||||
Zu = self.Z * self._Xscale + self._Xoffset
|
|
||||||
if self.input_dim==1:
|
|
||||||
ax.plot(self.X_batch, self.likelihood.data, 'gx',mew=2)
|
|
||||||
if Z_height is None:
|
|
||||||
Z_height = ax.get_ylim()[0]
|
|
||||||
ax.plot(Zu, np.zeros_like(Zu) + Z_height, 'r|', mew=1.5, markersize=12)
|
|
||||||
|
|
||||||
if self.input_dim==2:
|
|
||||||
ax.scatter(self.X[:,0], self.X[:,1], 20., self.Y[:,0], linewidth=0, cmap=pb.cm.jet) # @UndefinedVariable
|
|
||||||
ax.plot(Zu[:,0], Zu[:,1], 'w^')
|
|
||||||
|
|
||||||
def plot_traces(self):
|
def plot_traces(self):
|
||||||
pb.figure()
|
"""
|
||||||
t = np.array(self._param_trace)
|
See GPy.plotting.matplot_dep.svgi_plots.plot_traces
|
||||||
pb.subplot(2,1,1)
|
"""
|
||||||
for l,ti in zip(self._get_param_names(),t.T):
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
if not l[:3]=='iip':
|
from ..plotting.matplot_dep import svgi_plots
|
||||||
pb.plot(ti,label=l)
|
svgi_plots.plot_traces(self)
|
||||||
pb.legend(loc=0)
|
|
||||||
|
|
||||||
pb.subplot(2,1,2)
|
|
||||||
pb.plot(np.asarray(self._ll_trace),label='stochastic likelihood')
|
|
||||||
pb.legend(loc=0)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue